mysql.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. # ============================================================================
  2. # This file is part of Pwman3.
  3. #
  4. # Pwman3 is free software; you can redistribute it and/or modify
  5. # it under the terms of the GNU General Public License, version 2
  6. # as published by the Free Software Foundation;
  7. #
  8. # Pwman3 is distributed in the hope that it will be useful,
  9. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  11. # GNU General Public License for more details.
  12. #
  13. # You should have received a copy of the GNU General Public License
  14. # along with Pwman3; if not, write to the Free Software
  15. # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
  16. # ============================================================================
  17. # Copyright (C) 2012-2015 Oz Nahum <nahumoz@gmail.com>
  18. # ============================================================================
  19. #mysql -u root -p
  20. #create database pwmantest;
  21. #create user 'pwman'@'localhost' IDENTIFIED BY '123456';
  22. #grant all on pwmantest.* to 'pwman'@'localhost';
  23. """MySQL Database implementation."""
  24. from __future__ import print_function
  25. from pwman.data.database import Database, __DB_FORMAT__
  26. import sys
  27. import pymysql as mysql
  28. mysql.install_as_MySQLdb()
  29. #else:
  30. # import MySQLdb as mysql
  31. class MySQLDatabase(Database):
  32. @classmethod
  33. def check_db_version(cls, dburi):
  34. port = 3306
  35. credentials, host = dburi.netloc.split('@')
  36. user, passwd = credentials.split(':')
  37. if ':' in host:
  38. host, port = host.split(':')
  39. port = int(port)
  40. con = mysql.connect(host=host, port=port, user=user, passwd=passwd,
  41. db=dburi.path.lstrip('/'))
  42. cur = con.cursor()
  43. try:
  44. cur.execute("SELECT VERSION FROM DBVERSION")
  45. version = cur.fetchone()
  46. cur.close()
  47. con.close()
  48. return version[-1]
  49. except mysql.ProgrammingError:
  50. con.rollback()
  51. def __init__(self, mysqluri, dbformat=__DB_FORMAT__):
  52. self.dburi = mysqluri
  53. self.dbversion = dbformat
  54. def _open(self):
  55. port = 3306
  56. credentials, host = self.dburi.netloc.split('@')
  57. user, passwd = credentials.split(':')
  58. if ':' in host:
  59. host, port = host.split(':')
  60. port = int(port)
  61. self._con = mysql.connect(host=host, port=port, user=user,
  62. passwd=passwd,
  63. db=self.dburi.path.lstrip('/'))
  64. self._cur = self._con.cursor()
  65. self._create_tables()
  66. def _create_tables(self):
  67. try:
  68. self._cur.execute("SELECT 1 from DBVERSION")
  69. version = self._cur.fetchone()
  70. if version:
  71. return
  72. except mysql.ProgrammingError:
  73. self._con.rollback()
  74. try:
  75. self._cur.execute("CREATE TABLE NODE(ID SERIAL PRIMARY KEY, "
  76. "USERNAME TEXT NOT NULL, "
  77. "PASSWORD TEXT NOT NULL, "
  78. "URL TEXT NOT NULL, "
  79. "NOTES TEXT NOT NULL"
  80. ")")
  81. self._cur.execute("CREATE TABLE TAG"
  82. "(ID SERIAL PRIMARY KEY,"
  83. "DATA VARCHAR(255) NOT NULL UNIQUE)")
  84. self._cur.execute("CREATE TABLE LOOKUP ("
  85. "nodeid INTEGER NOT NULL REFERENCES NODE(ID),"
  86. "tagid INTEGER NOT NULL REFERENCES TAG(ID)"
  87. ")")
  88. self._cur.execute("CREATE TABLE CRYPTO "
  89. "(SEED TEXT, DIGEST TEXT)")
  90. self._cur.execute("CREATE TABLE DBVERSION("
  91. "VERSION TEXT NOT NULL "
  92. ")")
  93. self._cur.execute("INSERT INTO DBVERSION VALUES(%s)",
  94. (self.dbversion,))
  95. self._con.commit()
  96. except mysql.ProgrammingError: # pragma: no cover
  97. self._con.rollback()
  98. def getnodes(self, ids):
  99. if ids:
  100. sql = ("SELECT * FROM NODE WHERE ID IN ({})"
  101. "".format(','.join('%s' for i in ids)))
  102. else:
  103. sql = "SELECT * FROM NODE"
  104. self._cur.execute(sql, (ids))
  105. nodes = self._cur.fetchall()
  106. nodes_w_tags = []
  107. for node in nodes:
  108. tags = list(self._get_node_tags(node))
  109. nodes_w_tags.append(list(node) + tags)
  110. return nodes_w_tags
  111. def add_node(self, node):
  112. sql = ("INSERT INTO NODE(USERNAME, PASSWORD, URL, NOTES)"
  113. "VALUES(%s, %s, %s, %s)")
  114. node_tags = list(node)
  115. node, tags = node_tags[:4], node_tags[-1]
  116. self._cur.execute(sql, (node))
  117. nid = self._cur.lastrowid
  118. self._setnodetags(nid, tags)
  119. self._con.commit()
  120. def _get_node_tags(self, node):
  121. sql = "SELECT tagid FROM LOOKUP WHERE NODEID = %s"
  122. self._cur.execute(sql, (str(node[0]),))
  123. tagids = self._cur.fetchall()
  124. if tagids:
  125. sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
  126. "" % ','.join(['%s']*len(tagids)))
  127. tagids = [str(id[0]) for id in tagids]
  128. self._cur.execute(sql, (tagids))
  129. tags = self._cur.fetchall()
  130. for t in tags:
  131. yield t[0]
  132. def _setnodetags(self, nodeid, tags):
  133. for tag in tags:
  134. tid = self._get_or_create_tag(tag)
  135. self._update_tag_lookup(nodeid, tid)
  136. def _get_tag(self, tagcipher):
  137. sql_search = "SELECT ID FROM TAG WHERE DATA = %s"
  138. self._cur.execute(sql_search, ([tagcipher]))
  139. rv = self._cur.fetchone()
  140. return rv
  141. def _get_or_create_tag(self, tagcipher):
  142. rv = self._get_tag(tagcipher)
  143. if rv:
  144. return rv[0]
  145. else:
  146. sql_insert = "INSERT INTO TAG(DATA) VALUES(%s)"
  147. self._cur.execute(sql_insert, ([tagcipher]))
  148. return self._cur.lastrowid
  149. def _update_tag_lookup(self, nodeid, tid):
  150. sql_lookup = "INSERT INTO LOOKUP(nodeid, tagid) VALUES(%s, %s)"
  151. self._cur.execute(sql_lookup, (nodeid, tid))
  152. self._con.commit()
  153. def fetch_crypto_info(self):
  154. self._cur.execute("SELECT * FROM CRYPTO")
  155. row = self._cur.fetchone()
  156. return row
  157. def listtags(self):
  158. self._clean_orphans()
  159. get_tags = "select DATA from TAG"
  160. self._cur.execute(get_tags)
  161. tags = self._cur.fetchall()
  162. if tags:
  163. return [t[0] for t in tags]
  164. return [] # pragma: no cover
  165. def listnodes(self, filter=None):
  166. if not filter:
  167. sql_all = "SELECT ID FROM NODE"
  168. self._cur.execute(sql_all)
  169. ids = self._cur.fetchall()
  170. return [id[0] for id in ids]
  171. else:
  172. tagid = self._get_tag(filter)
  173. if not tagid:
  174. return [] # pragma: no cover
  175. sql_filter = "SELECT NODEID FROM LOOKUP WHERE TAGID = %s "
  176. self._cur.execute(sql_filter, (tagid))
  177. self._con.commit()
  178. ids = self._cur.fetchall()
  179. return [id[0] for id in ids]
  180. def save_crypto_info(self, seed, digest):
  181. """save the random seed and the digested key"""
  182. self._cur.execute("DELETE FROM CRYPTO")
  183. self._cur.execute("INSERT INTO CRYPTO VALUES(%s, %s)", (seed, digest))
  184. self._con.commit()
  185. def loadkey(self):
  186. sql = "SELECT * FROM CRYPTO"
  187. try:
  188. self._cur.execute(sql)
  189. seed, digest = self._cur.fetchone()
  190. return seed + u'$6$' + digest
  191. except TypeError: # pragma: no cover
  192. return None
  193. def _clean_orphans(self):
  194. clean = ("delete from TAG where not exists "
  195. "(select 'x' from LOOKUP l where l.TAGID = TAG.ID)")
  196. self._cur.execute(clean)
  197. def removenodes(self, nid):
  198. # shall we do this also in the sqlite driver?
  199. sql_clean = "DELETE FROM LOOKUP WHERE NODEID=%s"
  200. self._cur.execute(sql_clean, nid)
  201. sql_rm = "delete from NODE where ID = %s"
  202. self._cur.execute(sql_rm, nid)
  203. self._con.commit()
  204. self._con.commit()
  205. def savekey(self, key):
  206. salt, digest = key.split('$6$')
  207. sql = "INSERT INTO CRYPTO(SEED, DIGEST) VALUES(%s,%s)"
  208. self._cur.execute("DELETE FROM CRYPTO")
  209. self._cur.execute(sql, (salt, digest))
  210. self._digest = digest.encode('utf-8')
  211. self._salt = salt.encode('utf-8')
  212. self._con.commit()