commit dac0ab6263ee157c1199e206a6b3f12fcd35fd8e
Author: Cecylia Bocovich <cohosh@torproject.org>
Date:   Wed Feb 19 14:36:56 2020 -0500
Revert to using twisted for sqlite3 db
---
 README.md          |  2 --
 gettor/utils/db.py | 86 ++++++++++++++++++++++++++++--------------------------
 tests/conftests.py |  1 -
 tests/test_db.py   | 67 ------------------------------------------
 4 files changed, 45 insertions(+), 111 deletions(-)
diff --git a/README.md b/README.md index 6ad9603..f31d113 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,5 @@ GetTor includes PyTest unit tests. To run the tests, first install the dependenc
``` -$ python3 scripts/create_db -n -c -o -f tests/gettor.db -$ python3 scripts/add_links_to_db -f tests/gettor.db $ pytest-3 tests/ ``` diff --git a/gettor/utils/db.py b/gettor/utils/db.py index 7c3853f..0ca11aa 100644 --- a/gettor/utils/db.py +++ b/gettor/utils/db.py @@ -9,10 +9,10 @@
from __future__ import absolute_import
-import sqlite3 from datetime import datetime
from twisted.python import log +from twisted.enterprise import adbapi
class SQLite3(object): """ @@ -20,90 +20,94 @@ class SQLite3(object): """ def __init__(self, dbname): """Constructor.""" - self.conn = sqlite3.connect(dbname) + self.dbpool = adbapi.ConnectionPool( + "sqlite3", dbname, check_same_thread=False + ) + + def query_callback(self, results=None): + """ + Query callback + Log that the database query has been executed and return results + """ + log.msg("Database query executed successfully.") + return results + + def query_errback(self, error=None): + """ + Query error callback + Logs database error + """ + if error: + log.msg("Database error: {}".format(error)) + return None
def new_request(self, id, command, service, platform, language, date, status): """ Perform a new request to the database """ - c = self.conn.cursor() query = "INSERT INTO requests VALUES(?, ?, ?, ?, ?, ?, ?)"
- c.execute(query, (id, command, platform, language, service, - date, status)) - self.conn.commit() - return + return self.dbpool.runQuery( + query, (id, command, platform, language, service, date, status) + ).addCallback(self.query_callback).addErrback(self.query_errback)
def get_requests(self, status, command, service): """ Perform a SELECT request to the database """ - c = self.conn.cursor() query = "SELECT * FROM requests WHERE service=? AND command=? AND "\ "status = ?"
- c.execute(query, (service, command, status)) - - return c.fetchall() + return self.dbpool.runQuery( + query, (service, command, status) + ).addCallback(self.query_callback).addErrback(self.query_errback)
def get_num_requests(self, id, service): """ Get number of requests for statistics """ - c = self.conn.cursor() - query = "SELECT COUNT(rowid) FROM requests WHERE id=? AND "\ - "service=?" + query = "SELECT COUNT(rowid) FROM requests WHERE id=? AND service=?"
- c.execute(query, (id, service)) - return c.fetchone()[0] + return self.dbpool.runQuery( + query, (id, service) + ).addCallback(self.query_callback).addErrback(self.query_errback)
def remove_request(self, id, service, date): """ Removes completed request record from the database """ - c = self.conn.cursor() - query = "DELETE FROM requests WHERE id=? AND service=? AND "\ - "date=?" + query = "DELETE FROM requests WHERE id=? AND service=? AND date=?"
- c.execute(query, (id, service, date)) - self.conn.commit() - return + return self.dbpool.runQuery( + query, (id, service, date) + ).addCallback(self.query_callback).addErrback(self.query_errback)
def update_stats(self, command, service, platform=None, language='en'): """ Update statistics to the database """ - c = self.conn.cursor() now_str = datetime.now().strftime("%Y%m%d") query = "INSERT INTO stats(num_requests, platform, language, command, "\ - "service, date) VALUES (1, ?, ?, ?, ?, ?) ON "\ - "CONFLICT(platform, language, command, service, date) "\ - "DO UPDATE SET num_requests=num_requests+1" + "service, date) VALUES (1, ?, ?, ?, ?, ?) ON CONFLICT(platform, "\ + "language, command, service, date) DO UPDATE SET num_requests=num_requests+1"
- c.execute(query, (platform, language, command, service, - now_str)) - self.conn.commit() - return + return self.dbpool.runQuery( + query, (platform, language, command, service, now_str) + ).addCallback(self.query_callback).addErrback(self.query_errback)
def get_links(self, platform, language, status): """ Get links from the database per platform """ - c = self.conn.cursor() query = "SELECT * FROM links WHERE platform=? AND language=? AND status=?" - c.execute(query, (platform, language, status)) - - return c.fetchall() + return self.dbpool.runQuery( + query, (platform, language, status) + ).addCallback(self.query_callback).addErrback(self.query_errback)
def get_locales(self): """ Get a list of the supported tor browser binary locales """ - c = self.conn.cursor() query = "SELECT DISTINCT language FROM links" - c.execute(query) - - locales = [] - for locale in c.fetchall(): - locales.append(locale[0]) - return locales + return self.dbpool.runQuery(query + ).addCallback(self.query_callback).addErrback(self.query_errback) diff --git a/tests/conftests.py b/tests/conftests.py index d509776..cbb4d28 100644 --- a/tests/conftests.py +++ b/tests/conftests.py @@ -5,7 +5,6 @@ from __future__ import unicode_literals from gettor.utils import options from gettor.utils import strings from gettor.utils import twitter -from gettor.utils.db import SQLite3 from gettor.services.email.sendmail import Sendmail from gettor.services.twitter import twitterdm from gettor.parse.email import EmailParser, AddressError, DKIMError diff --git a/tests/test_db.py b/tests/test_db.py deleted file mode 100644 index d663d89..0000000 --- a/tests/test_db.py +++ /dev/null @@ -1,67 +0,0 @@ -#!/usr/bin/env python3 -import pytest -from datetime import datetime -from twisted.trial import unittest - -from . import conftests - -class DatabaseTests(unittest.TestCase): - - # Fail any tests which take longer than 15 seconds. 
- timeout = 15 - def setUp(self): - self.settings = conftests.options.parse_settings("en","tests/test.conf.json") - print(self.settings.get("dbname")) - self.db = conftests.SQLite3(self.settings.get("dbname")) - - def tearDown(self): - print("tearDown()") - - def add_dummy_requests(self, num): - now_str = datetime.now().strftime("%Y%m%d") - for i in (0, num): - self.db.new_request( - id='testid', - command='links', - platform='linux', - language='en', - service='email', - date=now_str, - status="ONHOLD", - ) - - def test_stored_locales(self): - locales = self.db.get_locales() - self.assertIn('en-US', locales) - - def test_requests(self): - now_str = datetime.now().strftime("%Y%m%d") - self.add_dummy_requests(2) - num = self.db.get_num_requests("testid", "email") - self.assertEqual(num, 2) - - requests = self.db.get_requests("ONHOLD", "links", "email") - for request in requests: - print(request) - self.assertEqual(request[1], "links") - self.assertEqual(request[4], "email") - self.assertEqual(request[5], now_str) - self.assertEqual(request[6], "ONHOLD") - self.assertEqual(len(requests), 2) - - self.db.remove_request("testid", "email", now_str) - num = self.db.get_num_requests("testid", "email") - self.assertEqual(num, 0) - - def test_links(self): - links = self.db.get_links("linux", "en-US", "ACTIVE") - self.assertEqual(len(links), 2) # Right now we have github and gitlab - - for link in links: - self.assertEqual(link[1], "linux") - self.assertEqual(link[2], "en-US") - self.assertEqual(link[6], "ACTIVE") - self.assertIn(link[5], ["github", "gitlab"]) - -if __name__ == "__main__": - unittest.main()
tor-commits@lists.torproject.org