Parcourir la source

Add more testing for do_list and do_filter

oz123 il y a 10 ans
Parent
commit
9e271f1d9a
2 fichiers modifiés avec 21 ajouts et 8 suppressions
  1. 9 2
      pwman/tests/test_base_ui.py
  2. 12 6
      pwman/ui/baseui.py

+ 9 - 2
pwman/tests/test_base_ui.py

@@ -52,7 +52,7 @@ class TestBaseUI(unittest.TestCase):
         self.assertListEqual(['foo', 'bar', 'baz'], tags)
         sys.stdin = sys.__stdin__
 
-    def test_do_new(self):
+    def test_1_do_new(self):
         sys.stdin = StringIO(("alice\nsecret\nexample.com\nsome notes"
                               "\nfoo bar baz"))
         _node = self.tester.cli.do_new('')
@@ -70,8 +70,15 @@ class TestBaseUI(unittest.TestCase):
         for idx, t in enumerate(['foo', 'bar', 'baz']):
             self.assertTrue(t, tags[idx])
 
-    def test_do_list(self):
+    def test_2_do_list(self):
+        self.output = StringIO()
+        self.saved_stdout = sys.stdout
+        sys.stdout = self.output
         self.tester.cli.do_list('')
+        self.tester.cli.do_list('foo')
+        self.tester.cli.do_list('bar')
+        sys.stdout = self.saved_stdout
+        self.output.getvalue()
 
 if __name__ == '__main__':
 

+ 12 - 6
pwman/ui/baseui.py

@@ -92,10 +92,8 @@ class BaseCommands(HelpUI):
         tags = [tn for tn in tagstrings]
         return tags
 
-    def _prep_term(self, args):
+    def _prep_term(self):
         self.do_cls('')
-        # TODO: fix do_filter!
-        #self.do_filter(args)
         if sys.platform != 'win32':
             rows, cols = tools.gettermsize()
         else:
@@ -117,11 +115,19 @@ class BaseCommands(HelpUI):
         formatted_entry = tools.typeset(fmt, Fore.YELLOW, False)
         print(formatted_entry)
 
+    def _get_node_ids(self, args):
+        filter = None
+        if args:
+            filter = args.split()[0]
+            ce = CryptoEngine.get()
+            filter = ce.encrypt(filter)
+        nodeids = self._db.listnodes(filter=filter)
+        return nodeids
+
     def do_list(self, args):
         """list all existing nodes in database"""
-        rows, cols = self._prep_term(args)
-
-        nodeids = self._db.listnodes()
+        rows, cols = self._prep_term()
+        nodeids = self._get_node_ids(args)
         nodes = self._db.getnodes(nodeids)
         _nodes_inst = []
         # user, pass, url, notes