# ============================================================================ # This file is part of Pwman3. # # Pwman3 is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License, version 2 # as published by the Free Software Foundation; # # Pwman3 is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with Pwman3; if not, write to the Free Software # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA # ============================================================================ # Copyright (C) 2012 Oz Nahum # ============================================================================ # Copyright (C) 2006 Ivan Kelly # ============================================================================ """Postgresql Database implementation.""" import sys if sys.version_info.major > 2: # pragma: no cover from urllib import urlparse else: from urlparse import urlparse import psycopg2 as pg import cPickle import pwman.util.config as config from pwman.data.database import Database, DatabaseException, __DB_FORMAT__ class PostgresqlDatabase(Database): """ Postgresql Database implementation This assumes that your database admin has created a pwman database for you and shared the user name and password with you. This driver send no clear text on wire. ONLY excrypted stuff is sent between the client and the server. Encryption and decryption are happening on your localhost, not on the Postgresql server. create table version CREATE TABLE DB_VERSION(DBVERSION TEXT NOT NULL DEFAULT ''); Check if db_version exists SELECT 1 FROM information_schema.tables where table_name = 'db_version'; get information: SELECT dbversion FROM DB_VERSION; """ @classmethod def check_db_version(cls, user, dbname='pwman'): """ Check the database version """ con = pg.connect("dbname=pwman user=%s" % user) cur = con.cursor() try: cur.execute("SELECT VERSION from DBVERSION") version = cur.fetchone() return version except pg.ProgrammingError: con.rollback() raise DatabaseException("Something seems fishy with the DB") def __init__(self, pgsqluri, dbformat=__DB_FORMAT__): """ Initialise PostgresqlDatabase instance. """ self._pgsqluri = pgsqluri def _open(self): u = urlparse(self._pgsqluri) self._con = pg.connect(database=u.path[1:], user=u.username, password=u.password, host=u.hostname) self._cur = self._con.cursor() self._cur.execute("CREATE TABLE DBVERSION(VERSION TEXT NOT NULL " "DEFAULT '0.6')") def _get_cur(self): try: if self._con is not None: return self._con.cursor() except pgdb.DatabaseError: pass server = "%s:%s" % (self._server, self._port) self._con = pgdb.connect(host=server, database=self._database, user=self._user, password=self._password) self._cur = self._con.cursor() return self._cur def close(self): self._cur.close() self._con.close() def listtags(self, all=False): sql = '' params = {} i = 0 if len(self._filtertags) == 0 or all: sql = "SELECT DATA FROM %sTAGS ORDER BY DATA ASC" % (self._prefix) else: sql = ("SELECT %sTAGS.DATA FROM %sLOOKUP" + " INNER JOIN %sTAGS ON %sLOOKUP.TAG = %sTAGS.ID" + " WHERE NODE IN (") % (self._prefix, self._prefix, self._prefix, self._prefix, self._prefix) first = True i += 1 paramname = "tag%d" % (i) for t in self._filtertags: if not first: sql += " INTERSECT " else: first = False sql += ("SELECT NODE FROM %sLOOKUP LEFT JOIN %sTAGS ON TAG " "= %sTAGS.ID " " WHERE %sTAGS.DATA = %%(%s)s" % (self._prefix, self._prefix, self._prefix, self._prefix, paramname)) params[paramname] = cPickle.dumps(t) sql += ") EXCEPT SELECT DATA FROM %sTAGS WHERE " % self._prefix first = True for t in self._filtertags: if not first: sql += " OR " else: first = False sql += "%sTAGS.DATA = %%(%s)s" % (self._prefix, paramname) try: cursor = self._get_cur() cursor.execute(sql, params) tags = [] row = cursor.fetchone() while row: tag = cPickle.loads(str(row[0])) tags.append(tag) row = cursor.fetchone() return tags except pgdb.DatabaseError as e: raise DatabaseException("Postgresql: %s" % (e)) def getnodes(self, ids): nodes = [] idstr = "" first = True if len(ids) == 0: idstr = "-1" for i in ids: if first: idstr += "%d" % (i) first = False else: idstr += ", %d" % (i) sql = ("SELECT ID,DATA FROM %sNODES WHERE ID IN (%s)" "" % (self._prefix, pgdb.escape_string(idstr))) try: cursor = self._get_cur() cursor.execute(sql) row = cursor.fetchone() while row: node = cPickle.loads(str(row[1])) node.set_id(row[0]) nodes.append(node) row = cursor.fetchone() except pgdb.DatabaseError as e: raise DatabaseException("Postgresql: %s" % (e)) return nodes def editnode(self, id, node): if not isinstance(node, Node): raise DatabaseException("Tried to insert foreign object into " "database [%s]" % node) try: cursor = self._get_cur() sql = ("UPDATE %sNODES SET DATA = %%(data)s WHERE ID = %%(id)d" "" % (self._prefix)) cursor.execute(sql, {"data": cPickle.dumps(node), "id": id}) except pgdb.DatabaseError as e: raise DatabaseException("Postgresql: %s" % (e)) self._setnodetags(node) self._checktags() self._commit() def addnodes(self, nodes): cursor = self._get_cur() for n in nodes: sql = "INSERT INTO %sNODES(DATA) VALUES(%%(data)s)" % (self._prefix) if not isinstance(n, Node): raise DatabaseException("Tried to insert foreign object into " "database [%s]", n) values = {"data": cPickle.dumps(n)} try: cursor.execute(sql, values) except pgdb.DatabaseError as e: raise DatabaseException("Postgresql: %s" % (e)) id = self._lastrowid("NODES") n.set_id(id) self._setnodetags(n) self._commit() def removenodes(self, nodes): cursor = self._get_cur() for n in nodes: if not isinstance(n, Node): raise DatabaseException("Tried to delete foreign object" "from database [%s]", n) try: sql = "DELETE FROM %sNODES WHERE ID = %%(id)d" % (self._prefix) cursor.execute(sql, {"id": n.get_id()}) except pgdb.DatabaseError as e: raise DatabaseException("Postgresql: %s" % (e)) self._deletenodetags(n) self._checktags() self._commit() def listnodes(self): sql = '' params = {} i = 0 cursor = self._get_cur() if len(self._filtertags) == 0: sql = "SELECT ID FROM %sNODES ORDER BY ID ASC" % (self._prefix) else: first = True for t in self._filtertags: if not first: sql += " INTERSECT " else: first = False i += 1 paramname = "tag%d" % (i) sql += (("SELECT NODE FROM %sLOOKUP LEFT JOIN %sTAGS ON" " TAG = %sTAGS.ID" " WHERE %sTAGS.DATA = %%(%s)s ") % (self._prefix, self._prefix, self._prefix, self._prefix, paramname)) params[paramname] = cPickle.dumps(t) try: cursor.execute(sql, params) ids = [] row = cursor.fetchone() while row: ids.append(row[0]) row = cursor.fetchone() return ids except pgdb.DatabaseError as e: raise DatabaseException("Postgresql: %s" % (e)) def _commit(self): try: self._con.commit() except pgdb.DatabaseError as e: self._con.rollback() raise DatabaseException( "Postgresql: Error commiting data to db [%s]" % (e)) def _tagids(self, tags): ids = [] cursor = self._get_cur() for t in tags: pickled = cPickle.dumps(t) try: ids.append(self._tagidcache[pickled]) continue except KeyError as e: pass # not in cache sql = "SELECT ID FROM %sTAGS WHERE DATA = %%(tag)s" % (self._prefix) if not isinstance(t, Tag): raise DatabaseException("Tried to insert foreign object" " into database [%s]", t) data = {"tag": pickled} try: cursor.execute(sql, data) row = cursor.fetchone() if row: ids.append(row[0]) self._tagidcache[pickled] = row[0] else: sql = "INSERT INTO %sTAGS(DATA) VALUES(%%(tag)s)" % ( self._prefix) cursor.execute(sql, data) id = self._lastrowid("TAGS") ids.append(id) self._tagidcache[pickled] = id except pgdb.DatabaseError as e: raise DatabaseException("Postgresql: %s" % (e)) return ids def _deletenodetags(self, node): try: cursor = self._get_cur() sql = "DELETE FROM %sLOOKUP WHERE NODE = %%(node)d" % (self._prefix) cursor.execute(sql, {"node": node.get_id()}) except pgdb.DatabaseError as e: raise DatabaseException("Postgresql: %s" % (e)) def _setnodetags(self, node): self._deletenodetags(node) ids = self._tagids(node.get_tags()) for i in ids: sql = "INSERT INTO %sLOOKUP VALUES(%%(node)d, %%(tag)d)" % ( self._prefix) params = {"node": node.get_id(), "tag": i} try: cursor = self._get_cur() cursor.execute(sql, params) except pgdb.DatabaseError as e: raise DatabaseException("Postgresql: %s" % (e)) def _checktags(self): self._tagidcache.clear() try: cursor = self._get_cur() sql = ("DELETE FROM %sTAGS WHERE ID NOT IN " + "(SELECT TAG FROM %sLOOKUP GROUP BY TAG)") % (self._prefix, self._prefix) cursor.execute(sql) except pgdb.DatabaseError as e: raise DatabaseException("Postgresql: %s" % (e)) self._commit() def _lastrowid(self, name): cursor = self._get_cur() cursor.execute("SELECT LAST_VALUE FROM %s%s_ID_SEQ" % (self._prefix, name)) row = cursor.fetchone() if not row: return 0 else: return row[0] def _checktables(self): """ Check if the Pwman tables exist """ cursor = self._get_cur() cursor.execute("SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE " "TABLE_NAME = '%snodes'" % (self._prefix)) if not cursor.fetchone(): # table doesn't exist, create it cursor.execute(("CREATE TABLE %sNODES " "(ID SERIAL PRIMARY KEY, DATA TEXT NOT NULL)" % (self._prefix))) cursor.execute(("CREATE TABLE %sTAGS" "(ID SERIAL PRIMARY KEY," "DATA TEXT NOT NULL UNIQUE)") % (self._prefix)) cursor.execute(("CREATE TABLE %sLOOKUP" "(NODE INTEGER NOT NULL, TAG INTEGER NOT NULL," " PRIMARY KEY(NODE, TAG))") % self._prefix) cursor.execute(("CREATE TABLE %sKEY" + "(THEKEY TEXT NOT NULL DEFAULT '')") % (self._prefix)) cursor.execute("INSERT INTO %sKEY VALUES('')" % (self._prefix)) try: self._con.commit() except pgdb.DatabaseError as e: self._con.rollback() raise e def savekey(self, key): sql = "UPDATE %sKEY SET THEKEY = %%(key)s" % (self._prefix) values = {"key": key} cursor = self._get_cur() cursor.execute(sql, values) try: self._con.commit() except pgdb.DatabaseError as e: self._con.rollback() raise DatabaseException( "Postgresql: Error saving key [%s]" % (e)) def loadkey(self): cursor = self._get_cur() cursor.execute("SELECT THEKEY FROM %sKEY" % (self._prefix)) keyrow = cursor.fetchone() if (keyrow[0] == ''): return None else: return keyrow[0]