[or-cvs] [bridgedb/master] Implement a sqlite replacement for our current db wrappers.

Nick Mathewson nickm at seul.org
Fri Sep 25 06:03:22 UTC 2009


Author: Nick Mathewson <nickm at torproject.org>
Date: Fri, 25 Sep 2009 01:18:38 -0400
Subject: Implement a sqlite replacement for our current db wrappers.
Commit: 1d739d1bfc7b544382066ebf9c6df7895c95cd60

Now all we'll need to do is reverse-engineer our current DB usage,
design a schema, write a migration tool, and switch the code to use
sqlite.

Such fun!
---
 lib/bridgedb/Storage.py |   92 +++++++++++++++++++++++++++++++++++++++++++++++
 lib/bridgedb/Tests.py   |   60 +++++++++++++++++++++++++++----
 2 files changed, 145 insertions(+), 7 deletions(-)
 create mode 100644 lib/bridgedb/Storage.py

diff --git a/lib/bridgedb/Storage.py b/lib/bridgedb/Storage.py
new file mode 100644
index 0000000..a0430c0
--- /dev/null
+++ b/lib/bridgedb/Storage.py
@@ -0,0 +1,92 @@
+# BridgeDB by Nick Mathewson.
+# Copyright (c) 2007-2009, The Tor Project, Inc.
+# See LICENSE for licensing information
+
+def _escapeValue(v):  # v must be a string; returns it as a SQL literal
+    return "'" + v.replace("'", "''") + "'"
+
+class SqliteDict:
+    """
+       A SqliteDict wraps a SQLite table and makes it look like a
+       Python dictionary.  In addition to the single key and value
+       columns, there can be a number of "fixed" columns, such that
+       the dictionary only contains elements of the table where the
+       fixed columns are set appropriately.
+
+       table, keycol, valcol and the names in fixedcolnames are put
+       into the SQL text unescaped, so they must come from trusted
+       code.  fixedcolnames must be a tuple, and fixedcolvalues a
+       same-length sequence of strings.
+    """
+    def __init__(self, conn, cursor, table, fixedcolnames, fixedcolvalues,
+                 keycol, valcol):
+        assert len(fixedcolnames) == len(fixedcolvalues)
+        self._conn = conn
+        self._cursor = cursor
+        # "col = 'value'" terms restricting every statement to our rows.
+        fixed = [ "%s = %s"%(c,_escapeValue(v))
+                  for c,v in zip(fixedcolnames, fixedcolvalues) ]
+        constraint = "WHERE %s"%(" AND ".join(["%s = ?"%keycol] + fixed))
+        keys = ", ".join(fixedcolnames+(keycol,valcol))
+        vals = "".join("%s, "%_escapeValue(v) for v in fixedcolvalues)
+
+        self._getStmt = "SELECT %s FROM %s %s"%(valcol,table,constraint)
+        self._delStmt = "DELETE FROM %s %s"%(table,constraint)
+        self._setStmt = "INSERT OR REPLACE INTO %s (%s) VALUES (%s?, ?)"%(
+            table, keys, vals)
+
+        # keys() filters on the fixed columns only, never on the key column.
+        whereClause = ""
+        if fixed:
+            whereClause = " WHERE %s"%(" AND ".join(fixed))
+        self._keysStmt = "SELECT %s FROM %s%s"%(keycol,table,whereClause)
+
+    def __setitem__(self, k, v):
+        self._cursor.execute(self._setStmt, (k,v))
+    def __delitem__(self, k):
+        self._cursor.execute(self._delStmt, (k,))
+        if self._cursor.rowcount == 0:
+            raise KeyError(k)
+    def __getitem__(self, k):
+        self._cursor.execute(self._getStmt, (k,))
+        val = self._cursor.fetchone()
+        if val is None:
+            raise KeyError(k)
+        return val[0]
+    def has_key(self, k):
+        # sqlite3 reports rowcount == -1 for SELECT, so fetch a row instead.
+        self._cursor.execute(self._getStmt, (k,))
+        return self._cursor.fetchone() is not None
+    __contains__ = has_key
+    def get(self, k, v=None):
+        try:
+            return self[k]
+        except KeyError:
+            return v
+    def setdefault(self, k, v):
+        try:
+            r = self[k]
+        except KeyError:
+            r = self[k] = v
+        return r
+    def keys(self):
+        self._cursor.execute(self._keysStmt)
+        return [ key for (key,) in self._cursor.fetchall() ]
+
+    def commit(self):
+        self._conn.commit()
+    def rollback(self):
+        self._conn.rollback()
+
+#
+#  The old DB system was just a key->value mapping DB, with special key
+#  prefixes to indicate which database they fell into.
+#
+#     sp|<HEXID> -- given to bridgesplitter; maps bridgeID to ring name.
+#     em|<emailaddr> -- given to emailbaseddistributor; maps email address
+#            to concatenated hexID.
+#     fs|<HEXID> -- given to BridgeTracker; maps to time when a router was
+#            first seen (YYYY-MM-DD HH:MM)
+#     ls|<HEXID> -- given to BridgeTracker; maps to time when a router was
+#            last seen (YYYY-MM-DD HH:MM)
+#
diff --git a/lib/bridgedb/Tests.py b/lib/bridgedb/Tests.py
index 9b9b1b9..865c91e 100644
--- a/lib/bridgedb/Tests.py
+++ b/lib/bridgedb/Tests.py
@@ -3,22 +3,22 @@
 # See LICENSE for licensing information
 
 import doctest
+import os
+import random
+import sqlite3
+import tempfile
 import unittest
 import warnings
-import random
 
 import bridgedb.Bridges
 import bridgedb.Main
 import bridgedb.Dist
 import bridgedb.Time
+import bridgedb.Storage
 
 def suppressWarnings():
     warnings.filterwarnings('ignore', '.*tmpnam.*')
 
-class TestCase0(unittest.TestCase):
-    def testFooIsFooish(self):
-        self.assert_(True)
-
 def randomIP():
     return ".".join([str(random.randrange(1,256)) for _ in xrange(4)])
 
@@ -82,11 +82,57 @@ class IPBridgeDistTests(unittest.TestCase):
             self.assertEquals(len(fps), 5)
             self.assertTrue(count >= 1)
 
+class StorageTests(unittest.TestCase):
+    """Exercise SqliteDict against a scratch SQLite database file."""
+    def setUp(self):
+        self.fd, self.fname = tempfile.mkstemp()
+        self.conn = sqlite3.Connection(self.fname)
+
+    def tearDown(self):
+        self.conn.close()
+        os.close(self.fd)
+        os.unlink(self.fname)
+
+    def testSimpleDict(self):
+        self.conn.execute("CREATE TABLE A ( X PRIMARY KEY, Y )")
+        wrapped = bridgedb.Storage.SqliteDict(
+            self.conn, self.conn.cursor(), "A", (), (), "X", "Y")
+        self.basictests(wrapped)
+
+    def testComplexDict(self):
+        # Two dictionaries share table B, told apart by fixed column X.
+        self.conn.execute("CREATE TABLE B ( X, Y, Z, "
+                          "CONSTRAINT B_PK PRIMARY KEY (X,Y) )")
+        wrapped = [ bridgedb.Storage.SqliteDict(
+            self.conn, self.conn.cursor(), "B", ("X",), (xval,), "Y", "Z")
+                    for xval in ("x1", "x2") ]
+        for each in wrapped:
+            self.basictests(each)
+
+    def basictests(self, d):
+        d["hello"] = "goodbye"
+        d["hola"] = "adios"
+        self.assertEquals(d["hola"], "adios")
+        d["hola"] = "hasta luego"
+        self.assertEquals(d["hola"], "hasta luego")
+        self.assertEquals(sorted(d.keys()), [u"hello", u"hola"])
+        self.assertRaises(KeyError, d.__getitem__, "buongiorno")
+        self.assertEquals(d.get("buongiorno", "ciao"), "ciao")
+        self.conn.commit()
+        d["buongiorno"] = "ciao"
+        del d["hola"]
+        self.assertRaises(KeyError, d.__getitem__, "hola")
+        self.conn.rollback()
+        self.assertEquals(d["hola"], "hasta luego")
+        self.assertEquals(d.setdefault("hola","bye"), "hasta luego")
+        self.assertEquals(d.setdefault("yo","bye"), "bye")
+        self.assertEquals(d["yo"], "bye")
+
 def testSuite():
     suite = unittest.TestSuite()
     loader = unittest.TestLoader()
 
-    for klass in [ TestCase0, IPBridgeDistTests ]:
+    for klass in [ IPBridgeDistTests, StorageTests ]:
         suite.addTest(loader.loadTestsFromTestCase(klass))
 
     for module in [ bridgedb.Bridges,
@@ -99,7 +145,7 @@ def testSuite():
 
 def main():
     suppressWarnings()
-    
+
     unittest.TextTestRunner(verbosity=1).run(testSuite())
 
 
-- 
1.5.6.5




More information about the tor-commits mailing list