mysql.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # ============================================================================
  2. # This file is part of Pwman3.
  3. #
  4. # Pwman3 is free software; you can redistribute it and/or modify
  5. # it under the terms of the GNU General Public License, version 2
  6. # as published by the Free Software Foundation;
  7. #
  8. # Pwman3 is distributed in the hope that it will be useful,
  9. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  11. # GNU General Public License for more details.
  12. #
  13. # You should have received a copy of the GNU General Public License
  14. # along with Pwman3; if not, write to the Free Software
  15. # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
  16. # ============================================================================
  17. # Copyright (C) 2012-2015 Oz Nahum <nahumoz@gmail.com>
  18. # ============================================================================
  19. # Example database setup, run from a MySQL shell (mysql -u root -p):
  20. #   CREATE DATABASE pwmantest;
  21. #   CREATE USER 'pwman'@'localhost' IDENTIFIED BY '123456';
  22. #   GRANT ALL ON pwmantest.* TO 'pwman'@'localhost';
  23. """MySQL Database implementation."""
  24. from __future__ import print_function
  25. from pwman.data.database import Database, __DB_FORMAT__
  26. import MySQLdb as mysql
  class MySQLDatabase(Database):
      """MySQL implementation of the Pwman3 ``Database`` interface.

      Connection details are read from a parsed URI object that exposes
      ``netloc`` (``user:password@host[:port]``) and ``path``
      (``/<database name>``) -- presumably a ``urlparse`` result; confirm
      against the callers.
      """

      @staticmethod
      def _connection_params(dburi):
          """Translate a parsed database URI into ``mysql.connect`` keywords.

          The port defaults to 3306 when ``netloc`` does not name one.  The
          credentials are cut at the last ``@`` and the first ``:`` so that
          a password containing either character is kept intact.

          Raises ``ValueError`` when ``netloc`` lacks the ``user:password@``
          part or the port is not an integer.
          """
          credentials, host = dburi.netloc.rsplit('@', 1)
          user, passwd = credentials.split(':', 1)
          port = 3306
          if ':' in host:
              host, port = host.split(':')
              port = int(port)
          return {'host': host, 'port': port, 'user': user, 'passwd': passwd,
                  'db': dburi.path.lstrip('/')}

      @classmethod
      def check_db_version(cls, dburi):
          """Return the version stored in the ``DBVERSION`` table.

          Opens a dedicated connection to the database described by *dburi*
          and always closes the cursor and the connection before returning,
          whether or not the query succeeded.

          Returns the last column of the first ``DBVERSION`` row, or ``None``
          when the query raises ``ProgrammingError`` (e.g. the table does
          not exist yet) or the table holds no rows.
          """
          con = mysql.connect(**cls._connection_params(dburi))
          try:
              cur = con.cursor()
              try:
                  cur.execute("SELECT VERSION FROM DBVERSION")
                  row = cur.fetchone()
              except mysql.ProgrammingError:
                  con.rollback()
                  return None
              finally:
                  cur.close()
          finally:
              con.close()
          # fetchone() yields None for an empty table; report "no version"
          # instead of failing on the subscript.
          return row[-1] if row else None

      def __init__(self, mysqluri, dbformat=__DB_FORMAT__):
          """Remember the connection URI and the schema version to record.

          mysqluri -- parsed URI of the MySQL database (see class docstring)
          dbformat -- version written into ``DBVERSION`` when the tables
                      are first created
          """
          # _open() reads ``self.dburi``; ``_mysqluri`` is kept as an alias
          # for code that relied on the previous attribute name.
          self.dburi = mysqluri
          self._mysqluri = mysqluri
          self.dbversion = dbformat

      def _open(self):
          """Connect to the database, create a cursor and ensure the schema.

          Stores the connection in ``self._con`` and the cursor in
          ``self._cur``, then calls ``_create_tables()``.
          """
          self._con = mysql.connect(**self._connection_params(self.dburi))
          self._cur = self._con.cursor()
          self._create_tables()

      def _create_tables(self):
          """Create the Pwman3 tables unless ``DBVERSION`` already has a row.

          Requires ``self._con`` and ``self._cur`` to be set (``_open()``
          does this).  After creating the tables, the value of
          ``self.dbversion`` is inserted into ``DBVERSION`` and the
          transaction is committed.
          """
          try:
              self._cur.execute("SELECT 1 from DBVERSION")
              if self._cur.fetchone():
                  # A version row exists, so the schema is already in place.
                  return
          except mysql.ProgrammingError:
              # The probe query failed (no DBVERSION table yet); fall through
              # and create the schema.
              self._con.rollback()

          try:
              self._cur.execute("CREATE TABLE NODE(ID SERIAL PRIMARY KEY, "
                                "USERNAME TEXT NOT NULL, "
                                "PASSWORD TEXT NOT NULL, "
                                "URL TEXT NOT NULL, "
                                "NOTES TEXT NOT NULL"
                                ")")
              # MySQL cannot build a UNIQUE index on a TEXT column without a
              # prefix length, hence VARCHAR for the tag data.
              self._cur.execute("CREATE TABLE TAG"
                                "(ID SERIAL PRIMARY KEY,"
                                "DATA VARCHAR(255) NOT NULL UNIQUE)")
              # SERIAL implies AUTO_INCREMENT and MySQL allows only one such
              # column per table; BIGINT UNSIGNED is the type SERIAL gives
              # the referenced ID columns.
              self._cur.execute("CREATE TABLE LOOKUP ("
                                "nodeid BIGINT UNSIGNED NOT NULL "
                                "REFERENCES NODE(ID),"
                                "tagid BIGINT UNSIGNED NOT NULL "
                                "REFERENCES TAG(ID)"
                                ")")
              self._cur.execute("CREATE TABLE CRYPTO "
                                "(SEED TEXT, DIGEST TEXT)")
              # MySQL rejects a literal DEFAULT on TEXT columns; the version
              # is inserted explicitly right below, so no default is needed.
              self._cur.execute("CREATE TABLE DBVERSION("
                                "VERSION TEXT NOT NULL"
                                ")")
              self._cur.execute("INSERT INTO DBVERSION VALUES(%s)",
                                (self.dbversion,))
              self._con.commit()
          except mysql.ProgrammingError:  # pragma: no cover
              # NOTE(review): the error is swallowed after the rollback, so
              # a failed schema creation is not reported to the caller.
              self._con.rollback()