ソースを参照

Merge branch 'master' of https://github.com/pwman3/pwman3

Oz N Tiram 9 年 前
コミット
64246b56b4

+ 1 - 0
.gitignore

@@ -13,3 +13,4 @@ MANIFEST
 htmlcov/*
 htmlcov/*
 .ropeproject*
 .ropeproject*
 secret.txt
 secret.txt
+.tox

+ 5 - 0
.travis.yml

@@ -3,6 +3,9 @@ python:
   - 2.7
   - 2.7
   - 3.4 
   - 3.4 
 
 
+services:
+  - mongodb
+
 before_script:
 before_script:
   - psql -c "CREATE USER tester WITH PASSWORD '123456';" -U postgres
   - psql -c "CREATE USER tester WITH PASSWORD '123456';" -U postgres
   - psql -c 'create database pwman;' -U postgres
   - psql -c 'create database pwman;' -U postgres
@@ -10,6 +13,7 @@ before_script:
   - mysql -e 'create database pwmantest' -uroot 
   - mysql -e 'create database pwmantest' -uroot 
   - mysql -e "create user 'pwman'@'localhost' IDENTIFIED BY '123456'" -uroot
   - mysql -e "create user 'pwman'@'localhost' IDENTIFIED BY '123456'" -uroot
   - mysql -e "grant all on pwmantest.* to 'pwman'@'localhost';" -uroot
   - mysql -e "grant all on pwmantest.* to 'pwman'@'localhost';" -uroot
+  - mongo pwmantest --eval 'db.addUser("tester", "12345678");'
 
 
 before_install:
 before_install:
   - sudo apt-get update -qq
   - sudo apt-get update -qq
@@ -17,6 +21,7 @@ before_install:
   - sudo apt-get install python-mysqldb
   - sudo apt-get install python-mysqldb
 # command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors
 # command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors
 install: 
 install: 
+  - "pip install pymongo==2.8"  
   - "pip install pymysql"   
   - "pip install pymysql"   
   - "pip install -r requirements.txt -r test_requirements.txt"
   - "pip install -r requirements.txt -r test_requirements.txt"
   - "pip install coveralls"
   - "pip install coveralls"

+ 1 - 1
Makefile

@@ -40,7 +40,7 @@ lint:
 
 
 test: install clean
 test: install clean
 	python setup.py test
 	python setup.py test
-	@rm -f pwman/tests/test.conf
+	@rm -f tests/test.conf
 
 
 test-all:
 test-all:
 	tox
 	tox

+ 1 - 1
README.md

@@ -5,7 +5,7 @@
 [![Documentation Status](https://readthedocs.org/projects/pwman3/badge/?version=latest)](https://readthedocs.org/projects/pwman3/?badge=latest)
 [![Documentation Status](https://readthedocs.org/projects/pwman3/badge/?version=latest)](https://readthedocs.org/projects/pwman3/?badge=latest)
 
 
 A nice command line password manager, which can use different database to store your passwords (currently, SQLite, MySQL, 
 A nice command line password manager, which can use different database to store your passwords (currently, SQLite, MySQL, 
-    and PostGresql are supported).  
+    and PostGresql and MongoDB are supported).  
 Pwman3 can also copy passwords to the clipboard without exposing them!
 Pwman3 can also copy passwords to the clipboard without exposing them!
 Besides managing and storing passwords, Pwman3 can also generate passwords using different algorithms. 
 Besides managing and storing passwords, Pwman3 can also generate passwords using different algorithms. 
 
 

+ 11 - 0
docs/source/install.rst

@@ -47,3 +47,14 @@ like to change your default Python interpreter to Python 3 serious, it is recomm
 that you export your database and re-import it to a new database created using Python 
 that you export your database and re-import it to a new database created using Python 
 3.X . 
 3.X . 
 
 
+Database versions 
+----------------- 
+
+The current version of Pwman3 is tested with Postgresql-9.3, MySQL-5.5,
+MongoDB 2.6.X and SQLite3. 
+
+The required python drivers are:
+ 
+ * pymysql  version 0.6.6 
+ * psycopg2 version 2.6
+ * pymongo version 2.8

+ 3 - 5
pwman/__init__.py

@@ -19,17 +19,15 @@
 # Copyright (C) 2006 Ivan Kelly <ivan@ivankelly.net>
 # Copyright (C) 2006 Ivan Kelly <ivan@ivankelly.net>
 # ============================================================================
 # ============================================================================
 import os
 import os
-import pkg_resources
 import argparse
 import argparse
 import sys
 import sys
 import re
 import re
 import colorama
 import colorama
 from pwman.util import config
 from pwman.util import config
-from pwman.data import factory
+from pwman.data.factory import check_db_version
 
 
 appname = "pwman3"
 appname = "pwman3"
 
 
-
 try:
 try:
     version = pkg_resources.get_distribution('pwman3').version
     version = pkg_resources.get_distribution('pwman3').version
 except pkg_resources.DistributionNotFound:  # pragma: no cover
 except pkg_resources.DistributionNotFound:  # pragma: no cover
@@ -123,5 +121,5 @@ def get_conf_options(args, OSX):
 
 
 
 
 def get_db_version(config, args):
 def get_db_version(config, args):
-    dbver = factory.check_db_version(config.get_value("Database", "dburi"))
-    return dbver
+    dburi = check_db_version(config.get_value("Database", "dburi"))
+    return dburi

+ 1 - 0
pwman/data/__init__.py

@@ -0,0 +1 @@
+from . import factory

+ 198 - 31
pwman/data/database.py

@@ -18,14 +18,12 @@
 # ============================================================================
 # ============================================================================
 # Copyright (C) 2006 Ivan Kelly <ivan@ivankelly.net>
 # Copyright (C) 2006 Ivan Kelly <ivan@ivankelly.net>
 # ============================================================================
 # ============================================================================
-
 from pwman.util.crypto_engine import CryptoEngine
 from pwman.util.crypto_engine import CryptoEngine
-
 __DB_FORMAT__ = 0.6
 __DB_FORMAT__ = 0.6
 
 
 
 
 class DatabaseException(Exception):
 class DatabaseException(Exception):
-    pass  # prage: no cover
+    pass  # pragma: no cover
 
 
 
 
 class Database(object):
 class Database(object):
@@ -47,8 +45,47 @@ class Database(object):
         else:
         else:
             self.get_user_password()
             self.get_user_password()
 
 
-    def close(self):
-        pass  # pragma: no cover
+    def _check_tables(self):
+        try:
+            self._cur.execute("SELECT 1 from DBVERSION")
+            version = self._cur.fetchone()
+            return version
+        except self.ProgrammingError:
+            self._con.rollback()
+
+    def _create_tables(self):
+
+        if self._check_tables():
+            return
+        try:
+            self._cur.execute("CREATE TABLE NODE(ID SERIAL PRIMARY KEY, "
+                              "USERNAME TEXT NOT NULL, "
+                              "PASSWORD TEXT NOT NULL, "
+                              "URL TEXT NOT NULL, "
+                              "NOTES TEXT NOT NULL"
+                              ")")
+
+            self._cur.execute("CREATE TABLE TAG"
+                              "(ID  SERIAL PRIMARY KEY,"
+                              "DATA VARCHAR(255) NOT NULL UNIQUE)")
+
+            self._cur.execute("CREATE TABLE LOOKUP ("
+                              "nodeid INTEGER NOT NULL REFERENCES NODE(ID),"
+                              "tagid INTEGER NOT NULL REFERENCES TAG(ID)"
+                              ")")
+
+            self._cur.execute("CREATE TABLE CRYPTO "
+                              "(SEED TEXT, DIGEST TEXT)")
+
+            self._cur.execute("CREATE TABLE DBVERSION("
+                              "VERSION TEXT NOT NULL)")
+
+            self._cur.execute("INSERT INTO DBVERSION VALUES(%s)",
+                              (self.dbversion,))
+
+            self._con.commit()
+        except self.ProgrammingError:  # pragma: no cover
+            self._con.rollback()
 
 
     def get_user_password(self):
     def get_user_password(self):
         """
         """
@@ -58,38 +95,168 @@ class Database(object):
         newkey = enc.changepassword()
         newkey = enc.changepassword()
         return self.savekey(newkey)
         return self.savekey(newkey)
 
 
-    def changepassword(self):
-        """
-        Change the databases password.
-        """
-        # TODO: call the converter here ...
-        # nodeids = self.listnodes()
-        # nodes = self.getnodes(nodeids)
-        # enc = CryptoEngine.get()
-        # oldkey = enc.get_cryptedkey()
-        # newkey = enc.changepassword()
-        # return newkey
+    def _clean_orphans(self):
+        clean = ("delete from TAG where not exists "
+                 "(select 'x' from LOOKUP l where l.TAGID = TAG.ID)")
+        self._cur.execute(clean)
+        self._con.commit()
 
 
-    def listtags(self, all=False):
-        pass  # pragma: no cover
+    def _get_node_tags(self, node):
+        sql = "SELECT tagid FROM LOOKUP WHERE NODEID = {}".format(self._sub)
+        self._cur.execute(sql, (str(node[0]),))
+        tagids = self._cur.fetchall()
+        if tagids:
+            sql = ("SELECT DATA FROM TAG WHERE ID IN"
+                   " ({})".format(','.join([self._sub]*len(tagids))))
+            tagids = [str(id[0]) for id in tagids]
+            self._cur.execute(sql, (tagids))
+            tags = self._cur.fetchall()
+            for t in tags:
+                yield t[0]
 
 
-    #def currenttags(self):
-    #    return self._filtertags
+    def _setnodetags(self, nodeid, tags):
+        for tag in tags:
+            tid = self._get_or_create_tag(tag)
+            self._update_tag_lookup(nodeid, tid)
 
 
-    def addnodes(self, nodes):
-        pass  # pragma: no cover
+    def _get_tag(self, tagcipher):
+        sql_search = "SELECT ID FROM TAG WHERE DATA = {}".format(self._sub)
+        self._cur.execute(sql_search, ([tagcipher]))
+        rv = self._cur.fetchone()
+        return rv
 
 
-    def editnode(self, id, node):
-        pass  # pragma: no cover
+    def _get_or_create_tag(self, tagcipher):
+        rv = self._get_tag(tagcipher)
+        if rv:
+            return rv[0]
+        else:
+            self._cur.execute(self._insert_tag_sql, ([tagcipher]))
+            try:
+                return self._cur.fetchone()[0]
+            except TypeError:
+                return self._cur.lastrowid
 
 
-    def removenodes(self, nodes):
-        pass  # pragma: no cover
+    def _update_tag_lookup(self, nodeid, tid):
+        sql_lookup = "INSERT INTO LOOKUP(nodeid, tagid) VALUES({}, {})".format(
+            self._sub, self._sub)
+        self._cur.execute(sql_lookup, (nodeid, tid))
+        self._con.commit()
 
 
-    def listnodes(self):
-        pass  # pragma: no cover
+    def getnodes(self, ids):
+        if ids:
+            sql = ("SELECT * FROM NODE WHERE ID IN ({})"
+                   "".format(','.join(self._sub for i in ids)))
+        else:
+            sql = "SELECT * FROM NODE"
+        self._cur.execute(sql, (ids))
+        nodes = self._cur.fetchall()
+        nodes_w_tags = []
+        for node in nodes:
+            tags = list(self._get_node_tags(node))
+            nodes_w_tags.append(list(node) + tags)
 
 
-    def savekey(self, key):
-        pass  # pragma: no cover
+        return nodes_w_tags
+
+    def listnodes(self, filter=None):
+        """return a list of node ids"""
+        if not filter:
+            sql_all = "SELECT ID FROM NODE"
+            self._cur.execute(sql_all)
+            ids = self._cur.fetchall()
+            return [id[0] for id in ids]
+        else:
+            tagid = self._get_tag(filter)
+            if not tagid:
+                return []  # pragma: no cover
+
+            self._cur.execute(self._list_nodes_sql, (tagid))
+            self._con.commit()
+            ids = self._cur.fetchall()
+            return [id[0] for id in ids]
+
+    def add_node(self, node):
+        node_tags = list(node)
+        node, tags = node_tags[:4], node_tags[-1]
+        self._cur.execute(self._add_node_sql, (node))
+        try:
+            nid = self._cur.fetchone()[0]
+        except TypeError:
+            nid = self._cur.lastrowid
+        self._setnodetags(nid, tags)
+        self._con.commit()
+
+    def listtags(self):
+        self._clean_orphans()
+        get_tags = "select DATA from TAG"
+        self._cur.execute(get_tags)
+        tags = self._cur.fetchall()
+        if tags:
+            return [t[0] for t in tags]
+        return []  # pragma: no cover
+
+    # TODO: add this to test of postgresql and mysql!
+    def editnode(self, nid, **kwargs):
+        tags = kwargs.pop('tags', None)
+        sql = ("UPDATE NODE SET {} WHERE ID = {} ".format(
+            ','.join(['{}={}'.format(k, self._sub) for k in list(kwargs)]),
+            self._sub))
+
+        self._cur.execute(sql, (list(kwargs.values()) + [nid]))
+        if tags:
+            # update all old node entries in lookup
+            # create new entries
+            # clean all old tags
+            sql_clean = "DELETE FROM LOOKUP WHERE NODEID={}".format(self._sub)
+            self._cur.execute(sql_clean, (str(nid),))
+            self._setnodetags(nid, tags)
+
+        self._con.commit()
+
+    def removenodes(self, nid):
+        # shall we do this also in the sqlite driver?
+        sql_clean = "DELETE FROM LOOKUP WHERE NODEID={}".format(self._sub)
+        self._cur.execute(sql_clean, nid)
+        sql_rm = "delete from NODE where ID = {}".format(self._sub)
+        self._cur.execute(sql_rm, nid)
+        self._con.commit()
+        self._con.commit()
+
+    def fetch_crypto_info(self):
+        self._cur.execute("SELECT * FROM CRYPTO")
+        row = self._cur.fetchone()
+        return row
+
+    def save_crypto_info(self, seed, digest):
+        """save the random seed and the digested key"""
+        self._cur.execute("DELETE  FROM CRYPTO")
+        self._cur.execute("INSERT INTO CRYPTO VALUES({}, {})".format(self._sub,
+                                                                     self._sub),
+                          (seed, digest))
+        self._con.commit()
 
 
     def loadkey(self):
     def loadkey(self):
-        pass  # pragma: no cover
+        """
+        return _keycrypted
+        """
+        sql = "SELECT * FROM CRYPTO"
+        try:
+            self._cur.execute(sql)
+            seed, digest = self._cur.fetchone()
+            return seed + u'$6$' + digest
+        except TypeError:  # pragma: no cover
+            return None
+
+    def savekey(self, key):
+        salt, digest = key.split('$6$')
+        sql = "INSERT INTO CRYPTO(SEED, DIGEST) VALUES({},{})".format(self._sub,
+                                                                      self._sub)
+        self._cur.execute("DELETE FROM CRYPTO")
+        self._cur.execute(sql, (salt, digest))
+        self._digest = digest.encode('utf-8')
+        self._salt = salt.encode('utf-8')
+        self._con.commit()
+
+    def close(self):  # pragma: no cover
+        self._clean_orphans()
+        self._cur.close()
+        self._con.close()

+ 21 - 0
pwman/data/drivers/__init__.py

@@ -0,0 +1,21 @@
+try:
+    from .sqlite import SQLite
+except ImportError:
+    SQLite = None
+
+try:
+    from .postgresql import PostgresqlDatabase
+except ImportError:
+    PostgresqlDatabase = None
+
+try:
+    from .mysql import MySQLDatabase
+except ImportError:
+    MySQLDatabase = None
+
+try:
+    from .mongodb import MongoDB
+except ImportError:
+    MongoDB = None
+
+__all__ = [SQLite, PostgresqlDatabase, MySQLDatabase, MongoDB]

+ 113 - 0
pwman/data/drivers/mongodb.py

@@ -0,0 +1,113 @@
+# ============================================================================
+# This file is part of Pwman3.
+#
+# Pwman3 is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License, version 2
+# as published by the Free Software Foundation;
+#
+# Pwman3 is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Pwman3; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+# ============================================================================
+# Copyright (C) 2015 Oz Nahum Tiram <nahumoz@gmail.com>
+# ============================================================================
+
+from pwman.data.database import Database, __DB_FORMAT__
+import pymongo
+
+
+class MongoDB(Database):
+
+    @classmethod
+    def check_db_version(cls, dburi):
+        return __DB_FORMAT__
+
+    def __init__(self, mongodb_uri, dbformat=__DB_FORMAT__):
+        self.uri = mongodb_uri.geturl()
+
+    def _open(self):
+        self._con = pymongo.Connection(self.uri)
+        self._db = self._con.get_default_database()
+
+        counters = self._db.counters.find()
+        if not counters.count():
+            self._db.counters.insert({'_id': 'nodeid', 'seq': 0})
+
+    def _get_next_node_id(self):
+        # for newer pymongo versions ...
+        # return_document=ReturnDocument.AFTER
+        nodeid = self._db.counters.find_and_modify(
+            {'_id': 'nodeid'}, {'$inc': {'seq': 1}}, new=True,
+            fields={'seq': 1, '_id': 0})
+        return nodeid['seq']
+
+    def getnodes(self, ids):
+        if ids:
+            ids = list(map(int, ids))
+            node_dicts = self._db.nodes.find({'_id': {'$in': ids}})
+        else:
+            node_dicts = self._db.nodes.find({})
+        nodes = []
+        for node in node_dicts:
+            n = [node['_id'],
+                 node['user'],
+                 node['password'],
+                 node['url'],
+                 node['notes']]
+
+            [n.append(t) for t in node['tags']]
+            nodes.append(n)
+
+        return nodes
+
+    def listnodes(self, filter=None):
+        if not filter:
+            nodes = self._db.nodes.find({}, {'_id': 1})
+
+        else:
+            nodes = self._db.nodes.find({"tags": {'$in': [filter]}}, {'_id': 1})
+
+        return [node['_id'] for node in list(nodes)]
+
+    def add_node(self, node):
+        nid = self._get_next_node_id()
+        node = node.to_encdict()
+        node['_id'] = nid
+        self._db.nodes.insert(node)
+        return nid
+
+    def listtags(self):
+        tags = self._db.nodes.distinct('tags')
+        return tags
+
+    def editnode(self, nid, **kwargs):
+        self._db.nodes.find_and_modify({'_id': nid}, kwargs)
+
+    def removenodes(self, nid):
+        nid = list(map(int, nid))
+        self._db.nodes.remove({'_id': {'$in': nid}})
+
+    def fetch_crypto_info(self):
+        pass
+
+    def savekey(self, key):
+        coll = self._db['crypto']
+        salt, digest = key.split('$6$')
+        coll.insert({'salt': salt, 'key': digest})
+
+    def loadkey(self):
+        coll = self._db['crypto']
+        try:
+            key = coll.find_one({}, {'_id': 0})
+            key = key['salt'] + '$6$' + key['key']
+        except TypeError:
+            key = None
+        return key
+
+    def close(self):
+        self._con.close()

+ 13 - 180
pwman/data/drivers/mysql.py

@@ -16,19 +16,16 @@
 # ============================================================================
 # ============================================================================
 # Copyright (C) 2012-2015 Oz Nahum <nahumoz@gmail.com>
 # Copyright (C) 2012-2015 Oz Nahum <nahumoz@gmail.com>
 # ============================================================================
 # ============================================================================
-#mysql -u root -p
-#create database pwmantest
-#create user 'pwman'@'localhost' IDENTIFIED BY '123456';
-#grant all on pwmantest.* to 'pwman'@'localhost';
+# mysql -u root -p
+# create database pwmantest
+# create user 'pwman'@'localhost' IDENTIFIED BY '123456';
+# grant all on pwmantest.* to 'pwman'@'localhost';
 
 
 """MySQL Database implementation."""
 """MySQL Database implementation."""
-from __future__ import print_function
 from pwman.data.database import Database, __DB_FORMAT__
 from pwman.data.database import Database, __DB_FORMAT__
 
 
 import pymysql as mysql
 import pymysql as mysql
 mysql.install_as_MySQLdb()
 mysql.install_as_MySQLdb()
-#else:
-#    import MySQLdb as mysql
 
 
 
 
 class MySQLDatabase(Database):
 class MySQLDatabase(Database):
@@ -53,9 +50,18 @@ class MySQLDatabase(Database):
         except mysql.ProgrammingError:
         except mysql.ProgrammingError:
             con.rollback()
             con.rollback()
 
 
+        return str(__DB_FORMAT__)
+
     def __init__(self, mysqluri, dbformat=__DB_FORMAT__):
     def __init__(self, mysqluri, dbformat=__DB_FORMAT__):
         self.dburi = mysqluri
         self.dburi = mysqluri
         self.dbversion = dbformat
         self.dbversion = dbformat
+        self._sub = "%s"
+        self._list_nodes_sql = "SELECT NODEID FROM LOOKUP WHERE TAGID = %s "
+        self._add_node_sql = ("INSERT INTO NODE(USERNAME, PASSWORD, URL, "
+                              "NOTES) "
+                              "VALUES(%s, %s, %s, %s)")
+        self._insert_tag_sql = "INSERT INTO TAG(DATA) VALUES(%s)"
+        self.ProgrammingError = mysql.ProgrammingError
 
 
     def _open(self):
     def _open(self):
 
 
@@ -71,176 +77,3 @@ class MySQLDatabase(Database):
                                   db=self.dburi.path.lstrip('/'))
                                   db=self.dburi.path.lstrip('/'))
         self._cur = self._con.cursor()
         self._cur = self._con.cursor()
         self._create_tables()
         self._create_tables()
-
-    def _create_tables(self):
-
-        try:
-            self._cur.execute("SELECT 1 from DBVERSION")
-            version = self._cur.fetchone()
-            if version:
-                return
-        except mysql.ProgrammingError:
-            self._con.rollback()
-
-        try:
-            self._cur.execute("CREATE TABLE NODE(ID SERIAL PRIMARY KEY, "
-                              "USERNAME TEXT NOT NULL, "
-                              "PASSWORD TEXT NOT NULL, "
-                              "URL TEXT NOT NULL, "
-                              "NOTES TEXT NOT NULL"
-                              ")")
-
-            self._cur.execute("CREATE TABLE TAG"
-                              "(ID  SERIAL PRIMARY KEY,"
-                              "DATA VARCHAR(255) NOT NULL UNIQUE)")
-
-            self._cur.execute("CREATE TABLE LOOKUP ("
-                              "nodeid INTEGER NOT NULL REFERENCES NODE(ID),"
-                              "tagid INTEGER NOT NULL REFERENCES TAG(ID)"
-                              ")")
-
-            self._cur.execute("CREATE TABLE CRYPTO "
-                              "(SEED TEXT, DIGEST TEXT)")
-
-            self._cur.execute("CREATE TABLE DBVERSION("
-                              "VERSION TEXT NOT NULL "
-                              ")")
-
-            self._cur.execute("INSERT INTO DBVERSION VALUES(%s)",
-                              (self.dbversion,))
-
-            self._con.commit()
-        except mysql.ProgrammingError:  # pragma: no cover
-            self._con.rollback()
-
-    def getnodes(self, ids):
-        if ids:
-            sql = ("SELECT * FROM NODE WHERE ID IN ({})"
-                   "".format(','.join('%s' for i in ids)))
-        else:
-            sql = "SELECT * FROM NODE"
-        self._cur.execute(sql, (ids))
-        nodes = self._cur.fetchall()
-        nodes_w_tags = []
-        for node in nodes:
-            tags = list(self._get_node_tags(node))
-            nodes_w_tags.append(list(node) + tags)
-
-        return nodes_w_tags
-
-    def add_node(self, node):
-        sql = ("INSERT INTO NODE(USERNAME, PASSWORD, URL, NOTES)"
-               "VALUES(%s, %s, %s, %s)")
-        node_tags = list(node)
-        node, tags = node_tags[:4], node_tags[-1]
-        self._cur.execute(sql, (node))
-        nid = self._cur.lastrowid
-        self._setnodetags(nid, tags)
-        self._con.commit()
-
-    def _get_node_tags(self, node):
-        sql = "SELECT tagid FROM LOOKUP WHERE NODEID = %s"
-        self._cur.execute(sql, (str(node[0]),))
-        tagids = self._cur.fetchall()
-        if tagids:
-            sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
-                   "" % ','.join(['%s']*len(tagids)))
-            tagids = [str(id[0]) for id in tagids]
-            self._cur.execute(sql, (tagids))
-            tags = self._cur.fetchall()
-            for t in tags:
-                yield t[0]
-
-    def _setnodetags(self, nodeid, tags):
-        for tag in tags:
-            tid = self._get_or_create_tag(tag)
-            self._update_tag_lookup(nodeid, tid)
-
-    def _get_tag(self, tagcipher):
-        sql_search = "SELECT ID FROM TAG WHERE DATA = %s"
-        self._cur.execute(sql_search, ([tagcipher]))
-        rv = self._cur.fetchone()
-        return rv
-
-    def _get_or_create_tag(self, tagcipher):
-        rv = self._get_tag(tagcipher)
-        if rv:
-            return rv[0]
-        else:
-            sql_insert = "INSERT INTO TAG(DATA) VALUES(%s)"
-            self._cur.execute(sql_insert, ([tagcipher]))
-            return self._cur.lastrowid
-
-    def _update_tag_lookup(self, nodeid, tid):
-        sql_lookup = "INSERT INTO LOOKUP(nodeid, tagid) VALUES(%s, %s)"
-        self._cur.execute(sql_lookup, (nodeid, tid))
-        self._con.commit()
-
-    def fetch_crypto_info(self):
-        self._cur.execute("SELECT * FROM CRYPTO")
-        row = self._cur.fetchone()
-        return row
-
-    def listtags(self):
-        self._clean_orphans()
-        get_tags = "select DATA from TAG"
-        self._cur.execute(get_tags)
-        tags = self._cur.fetchall()
-        if tags:
-            return [t[0] for t in tags]
-        return []  # pragma: no cover
-
-    def listnodes(self, filter=None):
-        if not filter:
-            sql_all = "SELECT ID FROM NODE"
-            self._cur.execute(sql_all)
-            ids = self._cur.fetchall()
-            return [id[0] for id in ids]
-        else:
-            tagid = self._get_tag(filter)
-            if not tagid:
-                return []  # pragma: no cover
-
-            sql_filter = "SELECT NODEID FROM LOOKUP WHERE TAGID = %s "
-            self._cur.execute(sql_filter, (tagid))
-            self._con.commit()
-            ids = self._cur.fetchall()
-            return [id[0] for id in ids]
-
-    def save_crypto_info(self, seed, digest):
-        """save the random seed and the digested key"""
-        self._cur.execute("DELETE  FROM CRYPTO")
-        self._cur.execute("INSERT INTO CRYPTO VALUES(%s, %s)", (seed, digest))
-        self._con.commit()
-
-    def loadkey(self):
-        sql = "SELECT * FROM CRYPTO"
-        try:
-            self._cur.execute(sql)
-            seed, digest = self._cur.fetchone()
-            return seed + u'$6$' + digest
-        except TypeError:  # pragma: no cover
-            return None
-
-    def _clean_orphans(self):
-        clean = ("delete from TAG where not exists "
-                 "(select 'x' from LOOKUP l where l.TAGID = TAG.ID)")
-        self._cur.execute(clean)
-
-    def removenodes(self, nid):
-        # shall we do this also in the sqlite driver?
-        sql_clean = "DELETE FROM LOOKUP WHERE NODEID=%s"
-        self._cur.execute(sql_clean, nid)
-        sql_rm = "delete from NODE where ID = %s"
-        self._cur.execute(sql_rm, nid)
-        self._con.commit()
-        self._con.commit()
-
-    def savekey(self, key):
-        salt, digest = key.split('$6$')
-        sql = "INSERT INTO CRYPTO(SEED, DIGEST) VALUES(%s,%s)"
-        self._cur.execute("DELETE FROM CRYPTO")
-        self._cur.execute(sql, (salt, digest))
-        self._digest = digest.encode('utf-8')
-        self._salt = salt.encode('utf-8')
-        self._con.commit()

+ 7 - 194
pwman/data/drivers/postgresql.py

@@ -54,6 +54,7 @@ class PostgresqlDatabase(Database):
             return version[-1]
             return version[-1]
         except pg.ProgrammingError:
         except pg.ProgrammingError:
             con.rollback()
             con.rollback()
+            return __DB_FORMAT__
 
 
     def __init__(self, pgsqluri, dbformat=__DB_FORMAT__):
     def __init__(self, pgsqluri, dbformat=__DB_FORMAT__):
         """
         """
@@ -61,203 +62,15 @@ class PostgresqlDatabase(Database):
         """
         """
         self._pgsqluri = pgsqluri
         self._pgsqluri = pgsqluri
         self.dbversion = dbformat
         self.dbversion = dbformat
+        self._sub = "%s"
+        self._list_nodes_sql = "SELECT NODEID FROM LOOKUP WHERE TAGID = %s "
+        self._add_node_sql = ('INSERT INTO NODE(USERNAME, PASSWORD, URL, '
+                              'NOTES) VALUES(%s, %s, %s, %s) RETURNING ID')
+        self._insert_tag_sql = "INSERT INTO TAG(DATA) VALUES(%s) RETURNING ID"
+        self.ProgrammingError = pg.ProgrammingError
 
 
     def _open(self):
     def _open(self):
 
 
         self._con = pg.connect(self._pgsqluri.geturl())
         self._con = pg.connect(self._pgsqluri.geturl())
         self._cur = self._con.cursor()
         self._cur = self._con.cursor()
         self._create_tables()
         self._create_tables()
-
-    def listnodes(self, filter=None):
-        if not filter:
-            sql_all = "SELECT ID FROM NODE"
-            self._cur.execute(sql_all)
-            ids = self._cur.fetchall()
-            return [id[0] for id in ids]
-        else:
-            tagid = self._get_tag(filter)
-            if not tagid:
-                return []  # pragma: no cover
-
-            sql_filter = "SELECT NODEID FROM LOOKUP WHERE TAGID = %s "
-            self._cur.execute(sql_filter, (tagid))
-            self._con.commit()
-            ids = self._cur.fetchall()
-            return [id[0] for id in ids]
-
-    def listtags(self):
-        self._clean_orphans()
-        get_tags = "select data from tag"
-        self._cur.execute(get_tags)
-        tags = self._cur.fetchall()
-        if tags:
-            return [t[0] for t in tags]
-        return []  # pragma: no cover
-
-    def _create_tables(self):
-
-        try:
-            self._cur.execute("SELECT 1 from DBVERSION")
-            version = self._cur.fetchone()
-            if version:
-                return
-        except pg.ProgrammingError:
-            self._con.rollback()
-
-        try:
-            self._cur.execute("CREATE TABLE NODE(ID SERIAL PRIMARY KEY, "
-                              "USERNAME TEXT NOT NULL, "
-                              "PASSWORD TEXT NOT NULL, "
-                              "URL TEXT NOT NULL, "
-                              "NOTES TEXT NOT NULL"
-                              ")")
-
-            self._cur.execute("CREATE TABLE TAG"
-                              "(ID SERIAL PRIMARY KEY,"
-                              "DATA TEXT NOT NULL UNIQUE)")
-
-            self._cur.execute("CREATE TABLE LOOKUP ("
-                              "nodeid INTEGER NOT NULL REFERENCES NODE(ID),"
-                              "tagid INTEGER NOT NULL REFERENCES TAG(ID)"
-                              ")")
-
-            self._cur.execute("CREATE TABLE CRYPTO "
-                              "(SEED TEXT, DIGEST TEXT)")
-
-            self._cur.execute("CREATE TABLE DBVERSION("
-                              "VERSION TEXT NOT NULL DEFAULT {}"
-                              ")".format(__DB_FORMAT__))
-
-            self._cur.execute("INSERT INTO DBVERSION VALUES(%s)",
-                              (self.dbversion,))
-
-            self._con.commit()
-        except pg.ProgrammingError:  # pragma: no cover
-            self._con.rollback()
-
-    def fetch_crypto_info(self):
-        self._cur.execute("SELECT * FROM CRYPTO")
-        row = self._cur.fetchone()
-        return row
-
-    def save_crypto_info(self, seed, digest):
-        """save the random seed and the digested key"""
-        self._cur.execute("DELETE  FROM CRYPTO")
-        self._cur.execute("INSERT INTO CRYPTO VALUES(%s, %s)", (seed, digest))
-        self._con.commit()
-
-    def add_node(self, node):
-        sql = ("INSERT INTO NODE(USERNAME, PASSWORD, URL, NOTES)"
-               "VALUES(%s, %s, %s, %s) RETURNING ID")
-        node_tags = list(node)
-        node, tags = node_tags[:4], node_tags[-1]
-        self._cur.execute(sql, (node))
-        nid = self._cur.fetchone()[0]
-        self._setnodetags(nid, tags)
-        self._con.commit()
-
-    def _get_tag(self, tagcipher):
-        sql_search = "SELECT ID FROM TAG WHERE DATA = %s"
-        self._cur.execute(sql_search, ([tagcipher]))
-        rv = self._cur.fetchone()
-        return rv
-
-    def _get_or_create_tag(self, tagcipher):
-        rv = self._get_tag(tagcipher)
-        if rv:
-            return rv[0]
-        else:
-            sql_insert = "INSERT INTO TAG(DATA) VALUES(%s) RETURNING ID"
-            self._cur.execute(sql_insert, ([tagcipher]))
-            rid = self._cur.fetchone()[0]
-            return rid
-
-    def _update_tag_lookup(self, nodeid, tid):
-        sql_lookup = "INSERT INTO LOOKUP(nodeid, tagid) VALUES(%s, %s)"
-        self._cur.execute(sql_lookup, (nodeid, tid))
-        self._con.commit()
-
-    def _setnodetags(self, nodeid, tags):
-        for tag in tags:
-            tid = self._get_or_create_tag(tag)
-            self._update_tag_lookup(nodeid, tid)
-
-    def _get_node_tags(self, node):
-        sql = "SELECT tagid FROM LOOKUP WHERE NODEID = %s"
-        self._cur.execute(sql, (str(node[0]),))
-        tagids = self._cur.fetchall()
-        if tagids:
-            sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
-                   "" % ','.join(['%s']*len(tagids)))
-            tagids = [str(id[0]) for id in tagids]
-            self._cur.execute(sql, (tagids))
-            tags = self._cur.fetchall()
-            for t in tags:
-                yield t[0]
-
-    def getnodes(self, ids):
-        if ids:
-            sql = ("SELECT * FROM NODE WHERE ID IN ({})"
-                   "".format(','.join('%s' for i in ids)))
-        else:
-            sql = "SELECT * FROM NODE"
-        self._cur.execute(sql, (ids))
-        nodes = self._cur.fetchall()
-        nodes_w_tags = []
-        for node in nodes:
-            tags = list(self._get_node_tags(node))
-            nodes_w_tags.append(list(node) + tags)
-
-        return nodes_w_tags
-
-    def editnode(self, nid, **kwargs):  # pragma: no cover
-        tags = kwargs.pop('tags', None)
-        sql = ("UPDATE NODE SET %s WHERE ID = %%s "
-               "" % ','.join('%s=%%s' % k for k in list(kwargs)))
-        self._cur.execute(sql, (list(kwargs.values()) + [nid]))
-        if tags:
-            # update all old node entries in lookup
-            # create new entries
-            # clean all old tags
-            sql_clean = "DELETE FROM LOOKUP WHERE NODEID=?"
-            self._cur.execute(sql_clean, (str(nid),))
-            self._setnodetags(nid, tags)
-
-        self._con.commit()
-
-    def removenodes(self, nid):
-        # shall we do this also in the sqlite driver?
-        sql_clean = "DELETE FROM LOOKUP WHERE NODEID=%s"
-        self._cur.execute(sql_clean, nid)
-        sql_rm = "delete from node where id = %s"
-        self._cur.execute(sql_rm, nid)
-        self._con.commit()
-
-    def _clean_orphans(self):
-        clean = ("delete from tag where not exists "
-                 "(select 'x' from lookup l where l.tagid = tag.id)")
-        self._cur.execute(clean)
-        self._con.commit()
-
-    def savekey(self, key):
-        salt, digest = key.split('$6$')
-        sql = "INSERT INTO CRYPTO(SEED, DIGEST) VALUES(%s,%s)"
-        self._cur.execute("DELETE FROM CRYPTO")
-        self._cur.execute(sql, (salt, digest))
-        self._digest = digest.encode('utf-8')
-        self._salt = salt.encode('utf-8')
-        self._con.commit()
-
-    def loadkey(self):
-        sql = "SELECT * FROM CRYPTO"
-        try:
-            self._cur.execute(sql)
-            seed, digest = self._cur.fetchone()
-            return seed + u'$6$' + digest
-        except TypeError:  # pragma: no cover
-            return None
-
-    def close(self):  # pragma: no cover
-        self._clean_orphans()
-        self._cur.close()
-        self._con.close()

+ 12 - 156
pwman/data/drivers/sqlite.py

@@ -20,8 +20,8 @@
 # ============================================================================
 # ============================================================================
 
 
 """SQLite Database implementation."""
 """SQLite Database implementation."""
-from pwman.data.database import Database
-from pwman.data.database import __DB_FORMAT__
+from __future__ import print_function
+from ..database import Database, __DB_FORMAT__
 import sqlite3 as sqlite
 import sqlite3 as sqlite
 
 
 
 
@@ -32,7 +32,11 @@ class SQLite(Database):
         """
         """
         check the database version.
         check the database version.
         """
         """
-        con = sqlite.connect(fname)
+        try:
+            con = sqlite.connect(fname)
+        except sqlite.OperationalError as E:
+            print("could not open %s" % fname)
+            raise E
         cur = con.cursor()
         cur = con.cursor()
         cur.execute("PRAGMA TABLE_INFO(DBVERSION)")
         cur.execute("PRAGMA TABLE_INFO(DBVERSION)")
         row = cur.fetchone()
         row = cur.fetchone()
@@ -45,37 +49,17 @@ class SQLite(Database):
         """Initialise SQLitePwmanDatabase instance."""
         """Initialise SQLitePwmanDatabase instance."""
         self._filename = filename
         self._filename = filename
         self.dbformat = dbformat
         self.dbformat = dbformat
+        self._add_node_sql = ("INSERT INTO NODE(USER, PASSWORD, URL, NOTES)"
+                              "VALUES(?, ?, ?, ?)")
+        self._list_nodes_sql = "SELECT NODEID FROM LOOKUP WHERE TAGID = ? "
+        self._insert_tag_sql = "INSERT INTO TAG(DATA) VALUES(?)"
+        self._sub = '?'
 
 
     def _open(self):
     def _open(self):
         self._con = sqlite.connect(self._filename)
         self._con = sqlite.connect(self._filename)
         self._cur = self._con.cursor()
         self._cur = self._con.cursor()
         self._create_tables()
         self._create_tables()
 
 
-    def listnodes(self, filter=None):
-        """return a list of node ids"""
-        if not filter:
-            sql_all = "SELECT ID FROM NODE"
-            self._cur.execute(sql_all)
-            ids = self._cur.fetchall()
-            return [id[0] for id in ids]
-        else:
-            tagid = self._get_tag(filter)
-            if not tagid:
-                return []
-            sql_filter = "SELECT NODEID FROM LOOKUP WHERE TAGID = ? "
-            self._cur.execute(sql_filter, (tagid))
-            ids = self._cur.fetchall()
-            return [id[0] for id in ids]
-
-    def listtags(self):
-        self._clean_orphans()
-        get_tags = "select data from tag"
-        self._cur.execute(get_tags)
-        tags = self._cur.fetchall()
-        if tags:
-            return [t[0] for t in tags]
-        return []
-
     def _create_tables(self):
     def _create_tables(self):
         self._cur.execute("PRAGMA TABLE_INFO(NODE)")
         self._cur.execute("PRAGMA TABLE_INFO(NODE)")
         if self._cur.fetchone() is not None:
         if self._cur.fetchone() is not None:
@@ -113,131 +97,3 @@ class SQLite(Database):
         except Exception as e:  # pragma: no cover
         except Exception as e:  # pragma: no cover
             self._con.rollback()
             self._con.rollback()
             raise e
             raise e
-
-    def fetch_crypto_info(self):
-        self._cur.execute("SELECT * FROM CRYPTO")
-        keyrow = self._cur.fetchone()
-        return keyrow
-
-    def save_crypto_info(self, seed, digest):
-        """save the random seed and the digested key"""
-        self._cur.execute("DELETE  FROM CRYPTO")
-        self._cur.execute("INSERT INTO CRYPTO VALUES(?, ?)", [seed, digest])
-        self._con.commit()
-
-    def add_node(self, node):
-        sql = ("INSERT INTO NODE(USER, PASSWORD, URL, NOTES)"
-               "VALUES(?, ?, ?, ?)")
-        node_tags = list(node)
-        node, tags = node_tags[:4], node_tags[-1]
-        self._cur.execute(sql, (node))
-        self._setnodetags(self._cur.lastrowid, tags)
-        self._con.commit()
-
-    def _get_tag(self, tagcipher):
-        sql_search = "SELECT ID FROM TAG WHERE DATA = ?"
-        self._cur.execute(sql_search, ([tagcipher]))
-        rv = self._cur.fetchone()
-        return rv
-
-    def _get_or_create_tag(self, tagcipher):
-        rv = self._get_tag(tagcipher)
-        if rv:
-            return rv[0]
-        else:
-            sql_insert = "INSERT INTO TAG(DATA) VALUES(?)"
-            self._cur.execute(sql_insert, ([tagcipher]))
-            return self._cur.lastrowid
-
-    def _update_tag_lookup(self, nodeid, tid):
-        sql_lookup = "INSERT INTO LOOKUP(nodeid, tagid) VALUES(?,?)"
-        self._cur.execute(sql_lookup, (nodeid, tid))
-        self._con.commit()
-
-    def _setnodetags(self, nodeid, tags):
-        for tag in tags:
-            tid = self._get_or_create_tag(tag)
-            self._update_tag_lookup(nodeid, tid)
-
-    def _get_node_tags(self, node):
-        sql = "SELECT tagid FROM LOOKUP WHERE NODEID = ?"
-        tagids = self._cur.execute(sql, (str(node[0]),)).fetchall()
-        sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
-               "" % ','.join('?'*len(tagids)))
-        tagids = [str(id[0]) for id in tagids]
-        self._cur.execute(sql, (tagids))
-        tags = self._cur.fetchall()
-        for t in tags:
-            yield t[0]
-
-    def getnodes(self, ids):
-        """
-        get nodes as raw ciphertext
-        """
-        if ids:
-            sql = ("SELECT * FROM NODE WHERE ID IN ({})"
-                   "".format(','.join('?'*len(ids))))
-        else:
-            sql = "SELECT * FROM NODE"
-
-        self._cur.execute(sql, (ids))
-        nodes = self._cur.fetchall()
-        nodes_w_tags = []
-        for node in nodes:
-            tags = list(self._get_node_tags(node))
-            nodes_w_tags.append(list(node) + tags)
-
-        return nodes_w_tags
-
-    def editnode(self, nid, **kwargs):
-        tags = kwargs.pop('tags', None)
-        sql = ("UPDATE NODE SET %s WHERE ID = ? "
-               "" % ','.join('%s=?' % k for k in list(kwargs)))
-        self._cur.execute(sql, (list(kwargs.values()) + [nid]))
-        if tags:
-            # update all old node entries in lookup
-            # create new entries
-            # clean all old tags
-            sql_clean = "DELETE FROM LOOKUP WHERE NODEID=?"
-            self._cur.execute(sql_clean, (str(nid),))
-            self._setnodetags(nid, tags)
-
-        self._con.commit()
-
-    def removenodes(self, nids):
-        sql_rm = "delete from node where id in (%s)" % ','.join('?'*len(nids))
-        self._cur.execute(sql_rm, (nids))
-        self._con.commit()
-
-    def _clean_orphans(self):
-        clean = ("delete from tag where not exists "
-                 "(select 'x' from lookup l where l.tagid = tag.id)")
-
-        self._cur.execute(clean)
-        self._con.commit()
-
-    def savekey(self, key):
-        salt, digest = key.split('$6$')
-        sql = "INSERT INTO CRYPTO(SEED, DIGEST) VALUES(?,?)"
-        self._cur.execute("DELETE FROM CRYPTO")
-        self._cur.execute(sql, (salt, digest))
-        self._digest = digest.encode('utf-8')
-        self._salt = salt.encode('utf-8')
-        self._con.commit()
-
-    def loadkey(self):
-        # TODO: rename this method!
-        """
-        return _keycrypted
-        """
-        sql = "SELECT * FROM CRYPTO"
-        try:
-            seed, digest = self._cur.execute(sql).fetchone()
-            return seed + u'$6$' + digest
-        except TypeError:
-            return None
-
-    def close(self):
-        self._clean_orphans()
-        self._cur.close()
-        self._con.close()

+ 48 - 47
pwman/data/factory.py

@@ -14,22 +14,9 @@
 # along with Pwman3; if not, write to the Free Software
 # along with Pwman3; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 # ============================================================================
 # ============================================================================
-# Copyright (C) 2012-2014 Oz Nahum Tiram <nahumoz@gmail.com>
+# Copyright (C) 2012-2015 Oz Nahum Tiram <nahumoz@gmail.com>
 # ============================================================================
 # ============================================================================
-# Copyright (C) 2006 Ivan Kelly <ivan@ivankelly.net>
-# ============================================================================
-
-"""
-Factory to create Database instances
-A Generic interface for all DB engines.
-Usage:
 
 
-import pwman.data.factory as DBFactory
-
-db = DBFactory.create(params)
-db.open()
-.....
-"""
 import sys
 import sys
 if sys.version_info.major > 2:  # pragma: no cover
 if sys.version_info.major > 2:  # pragma: no cover
     from urllib.parse import urlparse
     from urllib.parse import urlparse
@@ -38,47 +25,61 @@ else:
 
 
 import os
 import os
 
 
-from pwman.data.database import DatabaseException, __DB_FORMAT__
-from pwman.data.drivers import sqlite
-from pwman.data.drivers import postgresql
+from pwman.data.database import DatabaseException
+from pwman.data import drivers
+
+
+def parse_sqlite_uri(dburi):
+    filename = os.path.abspath(dburi.path)
+    return filename
+
+
+def parse_postgres_uri(dburi):
+    return dburi.geturl()
+
+
+def no_parse_uri(dburi):
+    return dburi
+
+
+class_db_map = {'sqlite':
+                ['SQLite', parse_sqlite_uri],
+                'postgresql': ['PostgresqlDatabase', parse_postgres_uri,
+                               'python-psycopg2'],
+                'mysql': ['MySQLDatabase', no_parse_uri, 'pymysql'],
+                'mongodb': ['MongoDB', no_parse_uri, 'pymongo']
+                }
+create_db_map = {'sqlite':
+                 ['SQLite', parse_sqlite_uri],
+                 'postgresql': ['PostgresqlDatabase', no_parse_uri,
+                                'python-psycopg2'],
+                 'mysql': ['MySQLDatabase', no_parse_uri, 'pymysql'],
+                 'mongodb': ['MongoDB', no_parse_uri, 'pymongo']
+                 }
 
 
 
 
 def check_db_version(dburi):
 def check_db_version(dburi):
 
 
-    ver = str(__DB_FORMAT__)
     dburi = urlparse(dburi)
     dburi = urlparse(dburi)
     dbtype = dburi.scheme
     dbtype = dburi.scheme
-    filename = os.path.abspath(dburi.path)
-    if dbtype == "sqlite":
-        ver = sqlite.SQLite.check_db_version(filename)
-    if dbtype == "postgresql":
-        #  ver = postgresql.PostgresqlDatabase.check_db_version(dburi)
-        ver = postgresql.PostgresqlDatabase.check_db_version(dburi.geturl())
-
-    return float(ver.strip("\'"))
+    try:
+        cls = getattr(drivers, class_db_map[dbtype][0])
+        ver = cls.check_db_version(class_db_map[dbtype][1](dburi))
+        return ver
+    except AttributeError:
+        raise DatabaseException(
+            '%s not installed? ' % class_db_map[dbtype][-1])
 
 
 
 
 def createdb(dburi, version):
 def createdb(dburi, version):
+
     dburi = urlparse(dburi)
     dburi = urlparse(dburi)
     dbtype = dburi.scheme
     dbtype = dburi.scheme
-    filename = dburi.path
-
-    if dbtype == "sqlite":
-        from pwman.data.drivers import sqlite
-        db = sqlite.SQLite(filename, dbformat=version)
-
-    elif dbtype == "postgresql":
-        try:
-            from pwman.data.drivers import postgresql
-            db = postgresql.PostgresqlDatabase(dburi)
-        except ImportError:  # pragma: no cover
-            raise DatabaseException("python-psycopg2 not installed")
-    elif dbtype == "mysql":  # pragma: no cover
-        try:
-            from pwman.data.drivers import mysql
-            db = mysql.MySQLDatabase()
-        except ImportError:
-            raise DatabaseException("python-mysqldb not installed")
-    else:
-        raise DatabaseException("Unknown database type specified")
-    return db
+    try:
+        cls = getattr(drivers, create_db_map[dbtype][0])
+        return cls(create_db_map[dbtype][1](dburi))
+    except AttributeError:
+        raise DatabaseException(
+            '%s not installed? ' % class_db_map[dbtype][-1])
+    except KeyError:
+        raise DatabaseException('Unknown database [%s] given ...' % (dbtype))

+ 6 - 2
pwman/ui/baseui.py

@@ -213,12 +213,11 @@ class BaseCommands(HelpUIMixin, AliasesMixin):
 
 
     def do_copy(self, args):  # pragma: no cover
     def do_copy(self, args):  # pragma: no cover
         """copy item to clipboard"""
         """copy item to clipboard"""
-        if not self._xsel:
+        if not self.hasxsel:
             return
             return
         if not args.isdigit():
         if not args.isdigit():
             print("Copy accepts only IDs ...")
             print("Copy accepts only IDs ...")
             return
             return
-
         ids = args.split()
         ids = args.split()
         if len(ids) > 1:
         if len(ids) > 1:
             print("Can copy only 1 password at a time...")
             print("Can copy only 1 password at a time...")
@@ -249,8 +248,13 @@ class BaseCommands(HelpUIMixin, AliasesMixin):
             url = ce.decrypt(node[3])
             url = ce.decrypt(node[3])
             if not url.startswith(("http://", "https://")):
             if not url.startswith(("http://", "https://")):
                 url = "https://" + url
                 url = "https://" + url
+            os.umask(22)
             tools.open_url(url)
             tools.open_url(url)
 
 
+            umask = self.config.get_value("Global", "umask")
+            if re.search(r'^\d{4}$', umask):
+                os.umask(int(umask))
+
     def do_exit(self, args):  # pragma: no cover
     def do_exit(self, args):  # pragma: no cover
         """close the text console"""
         """close the text console"""
         self._db.close()
         self._db.close()

+ 14 - 15
pwman/ui/cli.py

@@ -21,15 +21,6 @@
 from __future__ import print_function
 from __future__ import print_function
 import sys
 import sys
 import cmd
 import cmd
-import pwman
-from pwman.ui.baseui import BaseCommands
-from pwman import get_conf_options, get_db_version
-from pwman import parser_options
-from pwman.ui.tools import CLICallback
-import pwman.data.factory
-from pwman.exchange.importer import Importer
-from pwman.util.crypto_engine import CryptoEngine
-
 if sys.version_info.major > 2:
 if sys.version_info.major > 2:
     raw_input = input
     raw_input = input
 
 
@@ -39,6 +30,14 @@ try:
 except ImportError as e:  # pragma: no cover
 except ImportError as e:  # pragma: no cover
     _readline_available = False
     _readline_available = False
 
 
+from pwman.ui.baseui import BaseCommands
+from pwman import (get_conf_options, get_db_version, version, appname,
+                   parser_options, website)
+from pwman.ui.tools import CLICallback
+from pwman.data import factory
+from pwman.exchange.importer import Importer
+from pwman.util.crypto_engine import CryptoEngine
+
 
 
 class PwmanCli(cmd.Cmd, BaseCommands):
 class PwmanCli(cmd.Cmd, BaseCommands):
     """
     """
@@ -53,8 +52,8 @@ class PwmanCli(cmd.Cmd, BaseCommands):
         connecion, see if we have xsel ...
         connecion, see if we have xsel ...
         """
         """
         super(PwmanCli, self).__init__(**kwargs)
         super(PwmanCli, self).__init__(**kwargs)
-        self.intro = "%s %s (c) visit: %s" % (pwman.appname, pwman.version,
-                                              pwman.website)
+        self.intro = "%s %s (c) visit: %s" % (appname, version,
+                                              website)
         self._historyfile = config_parser.get_value("Readline", "history")
         self._historyfile = config_parser.get_value("Readline", "history")
         self.hasxsel = hasxsel
         self.hasxsel = hasxsel
         self.config = config_parser
         self.config = config_parser
@@ -76,7 +75,7 @@ class PwmanCli(cmd.Cmd, BaseCommands):
 
 
         self.prompt = "pwman> "
         self.prompt = "pwman> "
 
 
-		
+
 def get_ui_platform(platform):  # pragma: no cover
 def get_ui_platform(platform):  # pragma: no cover
     if 'darwin' in platform:
     if 'darwin' in platform:
         from pwman.ui.mac import PwmanCliMac as PwmanCli
         from pwman.ui.mac import PwmanCliMac as PwmanCli
@@ -100,8 +99,8 @@ def main():
     dbver = get_db_version(config, args)
     dbver = get_db_version(config, args)
     CryptoEngine.get()
     CryptoEngine.get()
 
 
-    
-    db = pwman.data.factory.createdb(dburi, dbver)
+
+    db = factory.createdb(dburi, dbver)
 
 
     if args.import_file:
     if args.import_file:
         importer = Importer((args, config, db))
         importer = Importer((args, config, db))
@@ -115,4 +114,4 @@ def main():
     except KeyboardInterrupt as e:
     except KeyboardInterrupt as e:
         print(e)
         print(e)
     finally:
     finally:
-        config.save()
+        config.save()

+ 4 - 4
pwman/ui/tools.py

@@ -116,15 +116,15 @@ def text_to_mcclipboard(text):  # pragma: no cover
         print (e, "\nExecuting pbcoy failed...")
         print (e, "\nExecuting pbcoy failed...")
 
 
 
 
-def open_url(link, macosx=False):  # pragma: no cover
+def open_url(link, macosx=False, ):  # pragma: no cover
     """
     """
     launch xdg-open or open in MacOSX with url
     launch xdg-open or open in MacOSX with url
     """
     """
-    uopen = "xdg-open"
+    uopen = "xdg-open "
     if macosx:
     if macosx:
-        uopen = "open"
+        uopen = "open "
     try:
     try:
-        sp.Popen([uopen, link], stdin=sp.PIPE)
+        sp.call(uopen+link, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
     except OSError as e:
     except OSError as e:
         print("Executing open_url failed with:\n", e)
         print("Executing open_url failed with:\n", e)
 
 

+ 0 - 74
scripts/pwman3

@@ -1,74 +0,0 @@
-#!/usr/bin/env python
-# ============================================================================
-# This file is part of Pwman3.
-#
-# Pwman3 is free software; you can redistribute it and/or modify
-# it under the terms of the GNU General Public License, version 2
-# as published by the Free Software Foundation;
-#
-# Pwman3 is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with Pwman3; if not, write to the Free Software
-# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
-# ============================================================================
-# Copyright (C) 2012-2014 Oz Nahum Tiram <nahumoz@gmail.com>
-# ============================================================================
-# Copyright (C) 2006 Ivan Kelly <ivan@ivankelly.net>
-# ============================================================================
-from __future__ import print_function
-import sys
-from pwman import get_conf_options, get_db_version
-from pwman import parser_options
-from pwman.ui.tools import CLICallback
-import pwman.data.factory
-from pwman.exchange.importer import Importer
-from pwman.util.crypto_engine import CryptoEngine
-
-if sys.version_info.major > 2:
-    raw_input = input
-
-
-def get_ui_platform(platform):  # pragma: no cover
-    if 'darwin' in platform:
-        from pwman.ui.mac import PwmanCliMac as PwmanCli
-        OSX = True
-    elif 'win' in platform:
-        from pwman.ui.win import PwmanCliWin as PwmanCli
-        OSX = False
-    else:
-        from pwman.ui.cli import PwmanCli
-        OSX = False
-
-    return PwmanCli, OSX
-
-
-def main(args):
-    PwmanCli, OSX = get_ui_platform(sys.platform)
-    xselpath, dbtype, config = get_conf_options(args, OSX)
-    dbver = get_db_version(config, args)
-    CryptoEngine.get()
-
-    dburi = config.get_value('Database', 'dburi')
-    db = pwman.data.factory.createdb(dburi, dbver)
-
-    if args.import_file:
-        importer = Importer((args, config, db))
-        importer.run()
-        sys.exit(0)
-
-    cli = PwmanCli(db, xselpath, CLICallback, config)
-
-    try:
-        cli.cmdloop()
-    except KeyboardInterrupt as e:
-        print(e)
-    finally:
-        config.save()
-
-if __name__ == '__main__':
-    args = parser_options().parse_args()
-    main(args)

+ 7 - 16
setup.py

@@ -13,7 +13,6 @@ from setuptools import find_packages
 import sys
 import sys
 from setuptools.command.install import install
 from setuptools.command.install import install
 import os
 import os
-from subprocess import Popen, PIPE
 import pwman
 import pwman
 
 
 # The BuildManPage code is distributed
 # The BuildManPage code is distributed
@@ -296,15 +295,6 @@ class ManPageCreator(object):
 sys.path.insert(0, os.getcwd())
 sys.path.insert(0, os.getcwd())
 
 
 
 
-def describe():
-    des = Popen('git describe', shell=True, stdout=PIPE)
-    ver = des.stdout.readlines()
-    if ver:
-        return ver[0].decode().strip()
-    else:
-        return pwman.version
-
-
 class PyCryptoInstallCommand(install):
 class PyCryptoInstallCommand(install):
 
 
     """
     """
@@ -329,7 +319,7 @@ class PyCryptoInstallCommand(install):
 
 
 
 
 setup(name=pwman.appname,
 setup(name=pwman.appname,
-      version=describe(),
+      version=pwman.version,
       description=pwman.description,
       description=pwman.description,
       long_description=pwman.long_description,
       long_description=pwman.long_description,
       author=pwman.author,
       author=pwman.author,
@@ -345,10 +335,11 @@ setup(name=pwman.appname,
                    'Intended Audience :: End Users/Desktop',
                    'Intended Audience :: End Users/Desktop',
                    'Intended Audience :: Developers',
                    'Intended Audience :: Developers',
                    'Intended Audience :: System Administrators',
                    'Intended Audience :: System Administrators',
-                   'License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)',
+                   ('License :: OSI Approved :: GNU General Public License'
+                    ' v3 or later (GPLv3+)'),
                    'Operating System :: OS Independent',
                    'Operating System :: OS Independent',
                    'Programming Language :: Python',
                    'Programming Language :: Python',
-                   'Programming Language :: Python :: 2.7'
+                   'Programming Language :: Python :: 2.7',
                    'Programming Language :: Python :: 3',
                    'Programming Language :: Python :: 3',
                    'Programming Language :: Python :: 3.2',
                    'Programming Language :: Python :: 3.2',
                    'Programming Language :: Python :: 3.3',
                    'Programming Language :: Python :: 3.3',
@@ -359,7 +350,7 @@ setup(name=pwman.appname,
           'install_pycrypto': PyCryptoInstallCommand,
           'install_pycrypto': PyCryptoInstallCommand,
           'build_manpage': BuildManPage
           'build_manpage': BuildManPage
       },
       },
-	  entry_points={
-	  'console_scripts': [ 'pwman-cli = pwman.ui.cli:main' ]
-		}
+      entry_points={
+          'console_scripts': ['pwman3 = pwman.ui.cli:main']
+          }
       )
       )

+ 2 - 0
test_requirements.txt

@@ -1,3 +1,5 @@
 psycopg2
 psycopg2
 pymysql
 pymysql
+pymongo==2.8
 pexpect
 pexpect
+coverage

+ 20 - 17
tests/test_complete_ui.py

@@ -20,22 +20,26 @@
 from __future__ import print_function
 from __future__ import print_function
 import pexpect
 import pexpect
 import unittest
 import unittest
+import sys
 import os
 import os
-import shutil
-
+from pwman import which
 
 
 class Ferrum(unittest.TestCase):
 class Ferrum(unittest.TestCase):
     def clean_files(self):
     def clean_files(self):
-        lfile = 'convert-test.log'
-        with open(lfile) as l:
-            lines = l.readlines()
-            orig = lines[0].split(':')[-1].strip()
-            backup = lines[1].split()[-1].strip()
-        shutil.copy(backup, orig)
+        #lfile = 'convert-test.log'
+        #with open(lfile) as l:
+        #    lines = l.readlines()
+        #    orig = lines[0].split(':')[-1].strip()
+        #    backup = lines[1].split()[-1].strip()
+        #shutil.copy(backup, orig)
         # do some cleaning
         # do some cleaning
-        os.remove(lfile)
-        os.remove('test-chg_passwd.log')
-        os.remove(backup)
+        # os.remove(lfile)
+        if os.path.exists('test-chg_passwd.log'):
+            os.remove('test-chg_passwd.log')
+        #os.remove(backup)
+        db = os.path.join(os.path.dirname(__file__), 'foo.baz.db')
+        if os.path.exists(db):
+            os.remove(db)
 
 
     @unittest.skip("obsolete")
     @unittest.skip("obsolete")
     def test_b_run_convert(self):
     def test_b_run_convert(self):
@@ -56,12 +60,11 @@ class Ferrum(unittest.TestCase):
     def test_c_change_pass(self):
     def test_c_change_pass(self):
         lfile = 'test-chg_passwd.log'
         lfile = 'test-chg_passwd.log'
         logfile = open(lfile, 'wb')
         logfile = open(lfile, 'wb')
-        child = pexpect.spawn(os.path.join(os.path.dirname(__file__),
-                                           '../scripts/pwman3') +
-                              ' -d ', logfile=logfile)
-        child.sendline('passwd')
-        child.expect("Please enter your current password:")
-        child.sendline('12345')
+        cmd = which('pwman3')
+        db = 'sqlite://' + os.path.join(os.path.dirname(__file__), 'foo.baz.db')
+        child = pexpect.spawn(cmd + ' -d ' + db, logfile=logfile)
+        if sys.version_info[0] > 2:
+            child.expect('[\s|\S]+(password:)$', timeout=10)
         child.sendline('foobar')
         child.sendline('foobar')
         child.sendline('foobar')
         child.sendline('foobar')
         self.clean_files()
         self.clean_files()

+ 1 - 1
tests/test_crypto_engine.py

@@ -59,7 +59,7 @@ class TestPassGenerator(unittest.TestCase):
     def test_len(self):
     def test_len(self):
         self.assertEqual(13, len(generate_password(pass_len=13)))
         self.assertEqual(13, len(generate_password(pass_len=13)))
 
 
-    def test_has_upper(self):
+    def test_has_no_lower(self):
         password = generate_password(uppercase=True, lowercase=False)
         password = generate_password(uppercase=True, lowercase=False)
         lower = set(string.ascii_lowercase)
         lower = set(string.ascii_lowercase)
         it = lower.intersection(set(password))
         it = lower.intersection(set(password))

+ 3 - 2
tests/test_factory.py

@@ -47,8 +47,9 @@ class TestFactory(unittest.TestCase):
         self.tester.create()
         self.tester.create()
 
 
     def test_factory_check_db_ver(self):
     def test_factory_check_db_ver(self):
-        self.assertEqual(factory.check_db_version('sqlite://'+testdb), 0.6)
+        self.assertEqual(factory.check_db_version('sqlite://'+testdb), u"'0.6'")
 
 
+    @unittest.skip("not supported at the moment")
     def test_factory_check_db_file(self):
     def test_factory_check_db_file(self):
         fn = os.path.join(os.path.dirname(__file__), 'baz.db')
         fn = os.path.join(os.path.dirname(__file__), 'baz.db')
         db = factory.createdb('sqlite:///'+os.path.abspath(fn), 0.3)
         db = factory.createdb('sqlite:///'+os.path.abspath(fn), 0.3)
@@ -65,7 +66,7 @@ class TestFactory(unittest.TestCase):
         os.unlink(fn)
         os.unlink(fn)
         self.assertIsInstance(db, SQLite)
         self.assertIsInstance(db, SQLite)
         self.assertRaises(DatabaseException, factory.createdb, *('UNKNOWN',
         self.assertRaises(DatabaseException, factory.createdb, *('UNKNOWN',
-                                                                 0.6))
+                                                                 __DB_FORMAT__))
 
 
     def test_factory_createdb(self):
     def test_factory_createdb(self):
         db = factory.createdb("sqlite:///test.db", 0.6)
         db = factory.createdb("sqlite:///test.db", 0.6)

+ 2 - 2
tests/test_init.py

@@ -74,9 +74,9 @@ class TestInit(unittest.TestCase):
 
 
     def test_get_db_version(self):
     def test_get_db_version(self):
         v = get_db_version(self.tester.configp, 'sqlite')
         v = get_db_version(self.tester.configp, 'sqlite')
-        self.assertEqual(v, __DB_FORMAT__)
+        self.assertEqual(v, u"'0.6'")
         v = get_db_version(self.tester.configp, 'sqlite')
         v = get_db_version(self.tester.configp, 'sqlite')
-        self.assertEqual(v, 0.6)
+        self.assertEqual(v, u"'0.6'")
         os.unlink(testdb)
         os.unlink(testdb)
 
 
     def test_set_xsel(self):
     def test_set_xsel(self):

+ 142 - 0
tests/test_mongodb.py

@@ -0,0 +1,142 @@
+# ============================================================================
+# This file is part of Pwman3.
+#
+# Pwman3 is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License, version 2
+# as published by the Free Software Foundation;
+#
+# Pwman3 is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Pwman3; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+# ============================================================================
+# Copyright (C) 2015 Oz Nahum Tiram <nahumoz@gmail.com>
+# ============================================================================
+
+import unittest
+import sys
+if sys.version_info.major > 2:  # pragma: no cover
+    from urllib.parse import urlparse
+else:  # pragma: no cover
+    from urlparse import urlparse
+import pymongo
+from .test_crypto_engine import give_key, DummyCallback
+from pwman.util.crypto_engine import CryptoEngine
+from pwman.data.drivers.mongodb import MongoDB
+from pwman.data.nodes import Node
+# use pwmantest
+
+# db.createUser(
+#    {
+#      user: "tester",
+#      pwd: "12345678",
+#       roles: [{ role: "dbAdmin", db: "pwmantest" },
+#               { role: "readWrite", db: "pwmantest" },]
+#    })
+
+
+class TestMongoDB(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        u = u"mongodb://tester:12345678@localhost:27017/pwmantest"
+        cls.db = MongoDB(urlparse(u))
+        cls.db._open()
+
+    @classmethod
+    def tearDownClass(cls):
+        coll = cls.db._db['crypto']
+        coll.drop()
+        cls.db._db['counters'].drop()
+        cls.db._db['nodes'].drop()
+        cls.db.close()
+
+    def test_1_con(self):
+        self.assertIsInstance(self.db._con, pymongo.Connection)
+
+    @unittest.skip("MongoDB creates collections on the fly")
+    def test_2_create_collections(self):
+        pass
+
+    def test_3a_load_key(self):
+        secretkey = self.db.loadkey()
+        self.assertIsNone(secretkey)
+
+    def test_3b_load_key(self):
+        self.db.savekey('SECRET$6$KEY')
+        secretkey = self.db.loadkey()
+        self.assertEqual(secretkey, u'SECRET$6$KEY')
+
+    @unittest.skip("")
+    def test_4_save_crypto(self):
+        self.db.save_crypto_info("TOP", "SECRET")
+        secretkey = self.db.loadkey()
+        self.assertEqual(secretkey, 'TOP$6$SECRET')
+        row = self.db.fetch_crypto_info()
+        self.assertEqual(row, ('TOP', 'SECRET'))
+
+    def test_5_add_node(self):
+        innode = [u"TBONE", u"S3K43T", u"example.org", u"some note",
+                  [u"bartag", u"footag"]]
+
+        kwargs = {
+            "username":innode[0], "password": innode[1],
+            "url": innode[2], "notes": innode[3], "tags": innode[4]
+        }
+
+        node = Node(clear_text=True, **kwargs)
+        self.db.add_node(node)
+        outnode = self.db.getnodes([1])[0]
+        no = outnode[1:5]
+        no.append(outnode[5:])
+        o = Node.from_encrypted_entries(*no)
+        self.assertEqual(list(node), list(o))
+
+    def test_6_list_nodes(self):
+        ret = self.db.listnodes()
+        self.assertEqual(ret, [1])
+        ce = CryptoEngine.get()
+        fltr = ce.encrypt("footag")
+        ret = self.db.listnodes(fltr)
+        self.assertEqual(ret, [1])
+
+    def test_6a_list_tags(self):
+        ret = self.db.listtags()
+        ce = CryptoEngine.get()
+        ec_tags = map(ce.encrypt,[u'bartag', u'footag'])
+        for t in ec_tags:
+            self.assertIn(t, ret)
+
+    def test_6b_get_nodes(self):
+        ret = self.db.getnodes([1])
+        retb = self.db.getnodes([])
+        self.assertListEqual(ret, retb)
+
+    @unittest.skip("tags are created in situ in mongodb")
+    def test_7_get_or_create_tag(self):
+        pass
+
+    @unittest.skip("tags are removed with their node")
+    def test_7a_clean_orphans(self):
+        pass
+
+    def test_8_remove_node(self):
+        self.db.removenodes([1])
+        n = self.db.listnodes()
+        self.assertEqual(len(n), 0)
+
+    @unittest.skip("No schema migration with mongodb")
+    def test_9_check_db_version(self):
+        pass
+
+
+if __name__ == '__main__':
+
+    ce = CryptoEngine.get()
+    ce.callback = DummyCallback()
+    ce.changepassword(reader=give_key)
+    unittest.main(verbosity=2, failfast=True)

+ 1 - 1
tests/test_mysql.py

@@ -117,7 +117,7 @@ class TestMySQLDatabase(unittest.TestCase):
         self.db._cur.execute("DROP TABLE DBVERSION")
         self.db._cur.execute("DROP TABLE DBVERSION")
         self.db._con.commit()
         self.db._con.commit()
         v = self.db.check_db_version(urlparse(dburi))
         v = self.db.check_db_version(urlparse(dburi))
-        self.assertEqual(v, None)
+        self.assertEqual(v, '0.6')
         self.db._cur.execute("CREATE TABLE DBVERSION("
         self.db._cur.execute("CREATE TABLE DBVERSION("
                              "VERSION TEXT NOT NULL) ")
                              "VERSION TEXT NOT NULL) ")
         self.db._con.commit()
         self.db._con.commit()

+ 4 - 4
tests/test_postgresql.py

@@ -30,8 +30,8 @@ from pwman.util.crypto_engine import CryptoEngine
 # testing on linux host
 # testing on linux host
 # su - postgres
 # su - postgres
 # psql
 # psql
-# postgres=# create user $YOUR_USERNAME;
-# postgres=# grant ALL ON DATABASE pwman to $YOUR_USERNAME;
+# postgres=# CREATE USER tester WITH PASSWORD '123456';
+# postgres=# grant ALL ON DATABASE pwman to tester;
 #
 #
 ##
 ##
 
 
@@ -120,11 +120,11 @@ class TestPostGresql(unittest.TestCase):
 
 
         dburi = "postgresql://tester:123456@localhost/pwman"
         dburi = "postgresql://tester:123456@localhost/pwman"
         v = self.db.check_db_version(dburi)
         v = self.db.check_db_version(dburi)
-        self.assertEqual(v, '0.6')
+        self.assertEqual(str(v), '0.6')
         self.db._cur.execute("DROP TABLE DBVERSION")
         self.db._cur.execute("DROP TABLE DBVERSION")
         self.db._con.commit()
         self.db._con.commit()
         v = self.db.check_db_version(dburi)
         v = self.db.check_db_version(dburi)
-        self.assertEqual(v, None)
+        self.assertEqual(str(v), '0.6')
         self.db._cur.execute("CREATE TABLE DBVERSION("
         self.db._cur.execute("CREATE TABLE DBVERSION("
                              "VERSION TEXT NOT NULL DEFAULT {}"
                              "VERSION TEXT NOT NULL DEFAULT {}"
                              ")".format('0.6'))
                              ")".format('0.6'))

+ 7 - 3
tests/test_pwman.py

@@ -21,12 +21,12 @@
 import os
 import os
 import sys
 import sys
 import unittest
 import unittest
-#from .test_tools import (SetupTester)
 from .test_crypto_engine import CryptoEngineTest, TestPassGenerator
 from .test_crypto_engine import CryptoEngineTest, TestPassGenerator
 from .test_config import TestConfig
 from .test_config import TestConfig
 from .test_sqlite import TestSQLite
 from .test_sqlite import TestSQLite
 from .test_postgresql import TestPostGresql
 from .test_postgresql import TestPostGresql
 from .test_mysql import TestMySQLDatabase
 from .test_mysql import TestMySQLDatabase
+from .test_mongodb import TestMongoDB
 from .test_importer import TestImporter
 from .test_importer import TestImporter
 from .test_factory import TestFactory
 from .test_factory import TestFactory
 from .test_base_ui import TestBaseUI
 from .test_base_ui import TestBaseUI
@@ -54,11 +54,15 @@ def suite():
     suite.addTest(loader.loadTestsFromTestCase(TestSQLite))
     suite.addTest(loader.loadTestsFromTestCase(TestSQLite))
     suite.addTest(loader.loadTestsFromTestCase(TestPostGresql))
     suite.addTest(loader.loadTestsFromTestCase(TestPostGresql))
     suite.addTest(loader.loadTestsFromTestCase(TestMySQLDatabase))
     suite.addTest(loader.loadTestsFromTestCase(TestMySQLDatabase))
+    suite.addTest(loader.loadTestsFromTestCase(TestMongoDB))
     suite.addTest(loader.loadTestsFromTestCase(TestImporter))
     suite.addTest(loader.loadTestsFromTestCase(TestImporter))
     suite.addTest(loader.loadTestsFromTestCase(TestFactory))
     suite.addTest(loader.loadTestsFromTestCase(TestFactory))
     suite.addTest(loader.loadTestsFromTestCase(TestBaseUI))
     suite.addTest(loader.loadTestsFromTestCase(TestBaseUI))
     suite.addTest(loader.loadTestsFromTestCase(TestInit))
     suite.addTest(loader.loadTestsFromTestCase(TestInit))
     suite.addTest(loader.loadTestsFromTestCase(TestNode))
     suite.addTest(loader.loadTestsFromTestCase(TestNode))
-    #if 'win' not in sys.platform:
-    #    suite.addTest(loader.loadTestsFromTestCase(Ferrum))
+    if 'win' not in sys.platform:
+        suite.addTest(loader.loadTestsFromTestCase(Ferrum))
     return suite
     return suite
+
+if __name__ == '__main__':
+    unittest.main(verbosity=2, failfast=True)

+ 2 - 1
tests/test_sqlite.py

@@ -151,7 +151,8 @@ class TestSQLite(unittest.TestCase):
         self.assertEqual(4, len(list(tags)))
         self.assertEqual(4, len(list(tags)))
 
 
     def test_a11_test_rmnodes(self):
     def test_a11_test_rmnodes(self):
-        self.db.removenodes([1, 2])
+        for n in [1, 2]:
+            self.db.removenodes([n])
         rv = self.db._cur.execute("select * from node").fetchall()
         rv = self.db._cur.execute("select * from node").fetchall()
         self.assertListEqual(rv, [])
         self.assertListEqual(rv, [])
 
 

+ 38 - 9
tox.ini

@@ -4,12 +4,41 @@
 # and then run "tox" from this directory.
 # and then run "tox" from this directory.
 
 
 [tox]
 [tox]
-envlist = py27, py34
-
-[testenv]
-commands = {envpython} setup.py test
-changedir = .
-deps = pexpect
-       pycrypto
-       colorama
-sitepackages=True
+envlist = py27,py34
+
+[testenv:py27]
+commands = coverage erase
+       {envbindir}/python setup.py develop
+       coverage run -p setup.py test
+       coverage combine
+
+deps = -rrequirements.txt 
+        pymongo==2.8
+        pymysql
+        psycopg2
+        pexpect
+        coverage
+
+[testenv:py34]
+commands = coverage3 erase
+       {envbindir}/python setup.py develop
+       coverage3 run -p setup.py test
+       coverage combine
+
+deps = -rrequirements.txt 
+        pymongo==2.8
+        pymysql
+        psycopg2
+        pexpect
+        coverage
+
+[testenv:docs]
+changedir = docs
+deps = -rrequirements.txt 
+        pymongo==2.8
+        pymysql
+        psycopg2
+        pexpect
+        sphinx
+commands =
+  sphinx-build -b html -d {envtmpdir}/doctrees source {envtmpdir}/html