[tor-commits] [bridgedb/develop] Use the new idiom throughout code

isis at torproject.org
Tue Apr 1 22:16:43 UTC 2014


commit 63f046df60767bae1a2ec31e67346c4ebb6769f8
Author: Matthew Finkel <Matthew.Finkel at gmail.com>
Date:   Fri Mar 21 03:10:43 2014 +0000

    Use the new idiom throughout code
---
 lib/bridgedb/Bridges.py   |   66 +++++++--------
 lib/bridgedb/Bucket.py    |  104 +++++++++++------------
 lib/bridgedb/Dist.py      |  115 ++++++++++++-------------
 lib/bridgedb/Stability.py |  204 ++++++++++++++++++++++-----------------------
 lib/bridgedb/Tests.py     |    4 +-
 5 files changed, 247 insertions(+), 246 deletions(-)
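
For context, the "new idiom" referred to in the commit message is the context-manager form of database access: every call site that used to grab the shared handle with `db = bridgedb.Storage.getDB()` now wraps its database work in `with bridgedb.Storage.getDB() as db:`, so acquisition and release of the handle are scoped to the block (and Tests.py switches from setGlobalDB() to setDB() to match). The following is only a minimal sketch of how such a wrapper is typically wired up; the lock, the _DBContext class, and the module globals are illustrative assumptions, not BridgeDB's actual bridgedb.Storage implementation:

    # Sketch of a context-manager getDB(), assuming a module-global handle
    # guarded by a lock.  Names other than getDB()/setDB() are hypothetical.
    import threading

    _THE_DB = None
    _LOCK = threading.RLock()

    def setDB(db):
        """Install the process-wide database object (as the tests now do)."""
        global _THE_DB
        _THE_DB = db

    class _DBContext(object):
        """What getDB() hands back; usable in a `with` statement."""
        def __enter__(self):
            _LOCK.acquire()          # serialize access to the shared handle
            return _THE_DB
        def __exit__(self, exc_type, exc_value, traceback):
            _LOCK.release()
            return False             # never swallow exceptions from the block

    def getDB():
        return _DBContext()

    # Usage, mirroring the pattern introduced throughout this commit:
    #     with getDB() as db:
    #         db.commit()

The practical effect visible in the diff below is that the handle cannot leak past the block, which is why BucketManager no longer keeps self.db around or closes it in __del__.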

diff --git a/lib/bridgedb/Bridges.py b/lib/bridgedb/Bridges.py
index 2fb784a..bfe1316 100644
--- a/lib/bridgedb/Bridges.py
+++ b/lib/bridgedb/Bridges.py
@@ -308,38 +308,38 @@ class Bridge(object):
         A bridge is 'familiar' if 1/8 of all active bridges have appeared
         more recently than it, or if it has been around for a Weighted Time of 8 days.
         """
-        db = bridgedb.Storage.getDB()
-        return db.getBridgeHistory(self.fingerprint).familiar
+        with bridgedb.Storage.getDB() as db:
+            return db.getBridgeHistory(self.fingerprint).familiar
 
     @property
     def wfu(self):
         """Weighted Fractional Uptime"""
-        db = bridgedb.Storage.getDB()
-        return db.getBridgeHistory(self.fingerprint).weightedFractionalUptime
+        with bridgedb.Storage.getDB() as db:
+            return db.getBridgeHistory(self.fingerprint).weightedFractionalUptime
 
     @property
     def weightedTime(self):
         """Weighted Time"""
-        db = bridgedb.Storage.getDB()
-        return db.getBridgeHistory(self.fingerprint).weightedTime
+        with bridgedb.Storage.getDB() as db:
+            return db.getBridgeHistory(self.fingerprint).weightedTime
 
     @property
     def wmtbac(self):
         """Weighted Mean Time Between Address Change"""
-        db = bridgedb.Storage.getDB()
-        return db.getBridgeHistory(self.fingerprint).wmtbac
+        with bridgedb.Storage.getDB() as db:
+            return db.getBridgeHistory(self.fingerprint).wmtbac
 
     @property
     def tosa(self):
         """the Time On Same Address (TOSA)"""
-        db = bridgedb.Storage.getDB()
-        return db.getBridgeHistory(self.fingerprint).tosa
+        with bridgedb.Storage.getDB() as db:
+            return db.getBridgeHistory(self.fingerprint).tosa
 
     @property
     def weightedUptime(self):
         """Weighted Uptime"""
-        db = bridgedb.Storage.getDB()
-        return db.getBridgeHistory(self.fingerprint).weightedUptime
+        with bridgedb.Storage.getDB() as db:
+            return db.getBridgeHistory(self.fingerprint).weightedUptime
 
 def getDescriptorDigests(desc):
     """Return the SHA-1 hash hexdigests of all descriptor descs
@@ -1122,19 +1122,19 @@ class UnallocatedHolder(BridgeHolder):
         self.fingerprints = []
 
     def dumpAssignments(self, f, description=""):
-        db = bridgedb.Storage.getDB()
-        allBridges = db.getAllBridges()
-        for bridge in allBridges:
-            if bridge.hex_key not in self.fingerprints:
-                continue
-            dist = bridge.distributor
-            desc = [ description ]
-            if dist.startswith(bridgedb.Bucket.PSEUDO_DISTRI_PREFIX):
-                dist = dist.replace(bridgedb.Bucket.PSEUDO_DISTRI_PREFIX, "")
-                desc.append("bucket=%s" % dist)
-            elif dist != "unallocated":
-                continue
-            f.write("%s %s\n" % (bridge.hex_key, " ".join(desc).strip()))
+        with bridgedb.Storage.getDB() as db:
+            allBridges = db.getAllBridges()
+            for bridge in allBridges:
+                if bridge.hex_key not in self.fingerprints:
+                    continue
+                dist = bridge.distributor
+                desc = [ description ]
+                if dist.startswith(bridgedb.Bucket.PSEUDO_DISTRI_PREFIX):
+                    dist = dist.replace(bridgedb.Bucket.PSEUDO_DISTRI_PREFIX, "")
+                    desc.append("bucket=%s" % dist)
+                elif dist != "unallocated":
+                    continue
+                f.write("%s %s\n" % (bridge.hex_key, " ".join(desc).strip()))
 
 class BridgeSplitter(BridgeHolder):
     """A BridgeHolder that splits incoming bridges up based on an hmac,
@@ -1186,7 +1186,6 @@ class BridgeSplitter(BridgeHolder):
 
     def insert(self, bridge):
         assert self.rings
-        db = bridgedb.Storage.getDB()
 
         for s in self.statsHolders:
             s.insert(bridge)
@@ -1205,16 +1204,17 @@ class BridgeSplitter(BridgeHolder):
 
         validRings = self.rings + self.pseudoRings
 
-        ringname = db.insertBridgeAndGetRing(bridge, ringname, time.time(), 
+        with bridgedb.Storage.getDB() as db:
+            ringname = db.insertBridgeAndGetRing(bridge, ringname, time.time(), 
                                              validRings)
-        db.commit()
+            db.commit()
 
-        # Pseudo distributors are always held in the "unallocated" ring
-        if ringname in self.pseudoRings:
-            ringname = "unallocated"
+            # Pseudo distributors are always held in the "unallocated" ring
+            if ringname in self.pseudoRings:
+                ringname = "unallocated"
 
-        ring = self.ringsByName.get(ringname)
-        ring.insert(bridge)
+            ring = self.ringsByName.get(ringname)
+            ring.insert(bridge)
 
     def dumpAssignments(self, f, description=""):
         for name,ring in self.ringsByName.iteritems():
diff --git a/lib/bridgedb/Bucket.py b/lib/bridgedb/Bucket.py
index 0038146..416c1ea 100644
--- a/lib/bridgedb/Bucket.py
+++ b/lib/bridgedb/Bucket.py
@@ -109,22 +109,18 @@ class BucketManager:
         self.unallocatedList = []
         self.unallocated_available = False
         self.distributor_prefix = PSEUDO_DISTRI_PREFIX
-        self.db = bridgedb.Storage.Database(self.cfg.DB_FILE+".sqlite",
-                                            self.cfg.DB_FILE)
-
-    def __del__(self):
-        self.db.close()
 
     def addToUnallocatedList(self, hex_key):
         """Add a bridge by hex_key into the unallocated pool
         """
-        try:
-            self.db.updateDistributorForHexKey("unallocated", hex_key)
-        except:
-            self.db.rollback()
-            raise
-        else:
-            self.db.commit()
+        with bridgedb.Storage.getDB() as db:
+            try:
+                db.updateDistributorForHexKey("unallocated", hex_key)
+            except:
+                db.rollback()
+                raise
+            else:
+                db.commit()
         self.unallocatedList.append(hex_key)
         self.unallocated_available = True
 
@@ -144,17 +140,18 @@ class BucketManager:
         # Mark pseudo-allocators in the database as such
         allocator_name = bucket.name
         #print "KEY: %d NAME: %s" % (hex_key, allocator_name)
-        try:
-            self.db.updateDistributorForHexKey(allocator_name, hex_key)
-        except:
-            self.db.rollback()
-            # Ok, this seems useless, but for consistancy's sake, we'll 
-            # re-assign the bridge from this missed db update attempt to the
-            # unallocated list. Remember? We pop()'d it before.
-            self.addToUnallocatedList(hex_key)
-            raise
-        else:
-            self.db.commit()
+        with bridgedb.Storage.getDB() as db:
+            try:
+                db.updateDistributorForHexKey(allocator_name, hex_key)
+            except:
+                db.rollback()
+                # Ok, this seems useless, but for consistancy's sake, we'll
+                # re-assign the bridge from this missed db update attempt to the
+                # unallocated list. Remember? We pop()'d it before.
+                self.addToUnallocatedList(hex_key)
+                raise
+            else:
+                db.commit()
         bucket.allocated += 1
         if len(self.unallocatedList) < 1:
             self.unallocated_available = False
@@ -171,7 +168,8 @@ class BucketManager:
             self.bucketList.append(d)
 
         # Loop through all bridges and sort out distributors
-        allBridges = self.db.getAllBridges()
+        with bridgedb.Storage.getDB() as db:
+            allBridges = db.getAllBridges()
         for bridge in allBridges:
             if bridge.distributor == "unallocated":
                 self.addToUnallocatedList(bridge.hex_key)
@@ -215,38 +213,40 @@ class BucketManager:
         logging.debug("Dumping bridge assignments to file: %r" % filename)
         # get the bridge histories and sort by Time On Same Address
         bridgeHistories = []
-        for b in bridges:
-            bh = self.db.getBridgeHistory(b.hex_key)
-            if bh: bridgeHistories.append(bh)
-        bridgeHistories.sort(lambda x,y: cmp(x.weightedFractionalUptime,
-            y.weightedFractionalUptime))
-
-        # for a bridge, get the list of countries it might not work in
-        blocklist = dict()
-        if getattr(self.cfg, "COUNTRY_BLOCK_FILE", None) is not None:
-            f = open(self.cfg.COUNTRY_BLOCK_FILE, 'r')
-            for ID,address,portlist,countries in bridgedb.Bridges.parseCountryBlockFile(f):
-                blocklist[toHex(ID)] = countries
-            f.close()
-
-        try:
-            f = open(filename, 'w')
-            for bh in bridgeHistories:
-                days = bh.tosa / long(60*60*24)
-                line = "%s:%s\t(%d %s)" %  \
-                        (bh.ip, bh.port, days,  _("""days at this address"""))
-                if str(bh.fingerprint) in blocklist.keys():
-                    line = line + "\t%s: (%s)" % (_("""(Might be blocked)"""),
-                            ",".join(blocklist[bh.fingerprint]),)
-                f.write(line + '\n')
-            f.close()
-        except IOError:
-            print "I/O error: %s" % filename
+        with bridgedb.Storage.getDB() as db:
+            for b in bridges:
+                bh = db.getBridgeHistory(b.hex_key)
+                if bh: bridgeHistories.append(bh)
+            bridgeHistories.sort(lambda x,y: cmp(x.weightedFractionalUptime,
+                y.weightedFractionalUptime))
+
+            # for a bridge, get the list of countries it might not work in
+            blocklist = dict()
+            if getattr(self.cfg, "COUNTRY_BLOCK_FILE", None) is not None:
+                f = open(self.cfg.COUNTRY_BLOCK_FILE, 'r')
+                for ID,address,portlist,countries in bridgedb.Bridges.parseCountryBlockFile(f):
+                    blocklist[toHex(ID)] = countries
+                f.close()
+
+            try:
+                f = open(filename, 'w')
+                for bh in bridgeHistories:
+                    days = bh.tosa / long(60*60*24)
+                    line = "%s:%s\t(%d %s)" %  \
+                            (bh.ip, bh.port, days,  _("""days at this address"""))
+                    if str(bh.fingerprint) in blocklist.keys():
+                        line = line + "\t%s: (%s)" % (_("""(Might be blocked)"""),
+                                ",".join(blocklist[bh.fingerprint]),)
+                    f.write(line + '\n')
+                f.close()
+            except IOError:
+                print "I/O error: %s" % filename
 
     def dumpBridges(self):
         """Dump all known file distributors to files, sort by distributor
         """
-        allBridges = self.db.getAllBridges()
+        with bridgedb.Storage.getDB() as db:
+            allBridges = db.getAllBridges()
         bridgeDict = {}
         # Sort returned bridges by distributor
         for bridge in allBridges:
diff --git a/lib/bridgedb/Dist.py b/lib/bridgedb/Dist.py
index dddc84b..3bfb2d8 100644
--- a/lib/bridgedb/Dist.py
+++ b/lib/bridgedb/Dist.py
@@ -474,58 +474,59 @@ class EmailBasedDistributor(Distributor):
         if emailaddress is None:
             return [] #XXXX raise an exception.
 
-        db = bridgedb.Storage.getDB()
-        wasWarned = db.getWarnedEmail(emailaddress)
-        lastSaw = db.getEmailTime(emailaddress)
-
-        logging.info("Attempting to return for %d bridges for %s..."
-                     % (N, Util.logSafely(emailaddress)))
-
-        if lastSaw is not None and lastSaw + MAX_EMAIL_RATE >= now:
-            logging.info("Client %s sent duplicate request within %d seconds."
-                         % (Util.logSafely(emailaddress), MAX_EMAIL_RATE))
-            if wasWarned:
-                logging.info(
-                    "Client was already warned about duplicate requests.")
-                raise IgnoreEmail("Client was warned",
-                                  Util.logSafely(emailaddress))
+        with bridgedb.Storage.getDB() as db:
+            wasWarned = db.getWarnedEmail(emailaddress)
+            lastSaw = db.getEmailTime(emailaddress)
+
+            logging.info("Attempting to return for %d bridges for %s..."
+                         % (N, Util.logSafely(emailaddress)))
+
+            if lastSaw is not None and lastSaw + MAX_EMAIL_RATE >= now:
+                logging.info("Client %s sent duplicate request within %d seconds."
+                             % (Util.logSafely(emailaddress), MAX_EMAIL_RATE))
+                if wasWarned:
+                    logging.info(
+                        "Client was already warned about duplicate requests.")
+                    raise IgnoreEmail("Client was warned",
+                                      Util.logSafely(emailaddress))
+                else:
+                    logging.info("Sending duplicate request warning to %s..."
+                                 % Util.logSafely(emailaddress))
+                    db.setWarnedEmail(emailaddress, True, now)
+                    db.commit()
+
+                raise TooSoonEmail("Too many emails; wait till later", emailaddress)
+
+            # warning period is over
+            elif wasWarned:
+                db.setWarnedEmail(emailaddress, False)
+
+            pos = self.emailHmac("<%s>%s" % (epoch, emailaddress))
+
+            ring = None
+            ruleset = frozenset(bridgeFilterRules)
+            if ruleset in self.splitter.filterRings.keys():
+                logging.debug("Cache hit %s" % ruleset)
+                _, ring = self.splitter.filterRings[ruleset]
             else:
-                logging.info("Sending duplicate request warning to %s..."
-                             % Util.logSafely(emailaddress))
-                db.setWarnedEmail(emailaddress, True, now)
-                db.commit()
-
-            raise TooSoonEmail("Too many emails; wait till later", emailaddress)
-
-        # warning period is over
-        elif wasWarned:
-            db.setWarnedEmail(emailaddress, False)
-
-        pos = self.emailHmac("<%s>%s" % (epoch, emailaddress))
-
-        ring = None
-        ruleset = frozenset(bridgeFilterRules)
-        if ruleset in self.splitter.filterRings.keys():
-            logging.debug("Cache hit %s" % ruleset)
-            _, ring = self.splitter.filterRings[ruleset]
-        else:
-            # cache miss, add new ring
-            logging.debug("Cache miss %s" % ruleset)
+                # cache miss, add new ring
+                logging.debug("Cache miss %s" % ruleset)
 
-            # add new ring
-            key1 = getHMAC(self.splitter.key, "Order-Bridges-In-Ring")
-            ring = bridgedb.Bridges.BridgeRing(key1, self.answerParameters)
-            # debug log: cache miss
-            self.splitter.addRing(ring, ruleset,
-                                  filterBridgesByRules(bridgeFilterRules),
-                                  populate_from=self.splitter.bridges)
+                # add new ring
+                key1 = getHMAC(self.splitter.key,
+                                                 "Order-Bridges-In-Ring")
+                ring = bridgedb.Bridges.BridgeRing(key1, self.answerParameters)
+                # debug log: cache miss
+                self.splitter.addRing(ring, ruleset,
+                                      filterBridgesByRules(bridgeFilterRules),
+                                      populate_from=self.splitter.bridges)
 
-        numBridgesToReturn = getNumBridgesPerAnswer(ring,
-                                                    max_bridges_per_answer=N)
-        result = ring.getBridges(pos, numBridgesToReturn)
+            numBridgesToReturn = getNumBridgesPerAnswer(ring,
+                                                        max_bridges_per_answer=N)
+            result = ring.getBridges(pos, numBridgesToReturn)
 
-        db.setEmailTime(emailaddress, now)
-        db.commit()
+            db.setEmailTime(emailaddress, now)
+            db.commit()
 
         return result
 
@@ -533,15 +534,15 @@ class EmailBasedDistributor(Distributor):
         return len(self.splitter)
 
     def cleanDatabase(self):
-        db = bridgedb.Storage.getDB()
-        try:
-            db.cleanEmailedBridges(time.time()-MAX_EMAIL_RATE)
-            db.cleanWarnedEmails(time.time()-MAX_EMAIL_RATE)
-        except:
-            db.rollback()
-            raise
-        else:
-            db.commit()
+        with bridgedb.Storage.getDB() as db:
+            try:
+                db.cleanEmailedBridges(time.time()-MAX_EMAIL_RATE)
+                db.cleanWarnedEmails(time.time()-MAX_EMAIL_RATE)
+            except:
+                db.rollback()
+                raise
+            else:
+                db.commit()
 
     def dumpAssignments(self, f, description=""):
         self.splitter.dumpAssignments(f, description)
diff --git a/lib/bridgedb/Stability.py b/lib/bridgedb/Stability.py
index 5cd7e8e..4c3777d 100644
--- a/lib/bridgedb/Stability.py
+++ b/lib/bridgedb/Stability.py
@@ -111,14 +111,14 @@ class BridgeHistory(object):
 
         # return True if self.weightedTime is greater than the weightedTime
         # of the > bottom 1/8 all bridges, sorted by weightedTime
-        db = bridgedb.Storage.getDB()
-        allWeightedTimes = [ bh.weightedTime for bh in db.getAllBridgeHistory()]
-        numBridges = len(allWeightedTimes)
-        logging.debug("Got %d weightedTimes", numBridges)
-        allWeightedTimes.sort()
-        if self.weightedTime >= allWeightedTimes[numBridges/8]:
-            return True
-        return False
+        with bridgedb.Storage.getDB() as db:
+            allWeightedTimes = [ bh.weightedTime for bh in db.getAllBridgeHistory()]
+            numBridges = len(allWeightedTimes)
+            logging.debug("Got %d weightedTimes", numBridges)
+            allWeightedTimes.sort()
+            if self.weightedTime >= allWeightedTimes[numBridges/8]:
+                return True
+            return False
 
     @property
     def wmtbac(self):
@@ -134,104 +134,104 @@ class BridgeHistory(object):
         return totalRunlength / totalWeights
 
 def addOrUpdateBridgeHistory(bridge, timestamp):
-    db = bridgedb.Storage.getDB()
-    bhe = db.getBridgeHistory(bridge.fingerprint)
-    if not bhe:
-        # This is the first status, assume 60 minutes.
-        secondsSinceLastStatusPublication = long(60*60)
-        lastSeenWithDifferentAddressAndPort = timestamp * long(1000)
-        lastSeenWithThisAddressAndPort = timestamp * long(1000)
-    
-        bhe = BridgeHistory(
-                bridge.fingerprint, bridge.ip, bridge.orport, 
-                0,#weightedUptime
-                0,#weightedTime
-                0,#weightedRunLength
-                0,# totalRunWeights
-                lastSeenWithDifferentAddressAndPort, # first timestamnp
-                lastSeenWithThisAddressAndPort,
-                0,#lastDiscountedHistoryValues,
-                0,#lastUpdatedWeightedTime
-                )
-        # first time we have seen this descriptor
-        db.updateIntoBridgeHistory(bhe)
-    # Calculate the seconds since the last parsed status.  If this is
-    # the first status or we haven't seen a status for more than 60
-    # minutes, assume 60 minutes.
-    statusPublicationMillis = long(timestamp * 1000)
-    if (statusPublicationMillis - bhe.lastSeenWithThisAddressAndPort) > 60*60*1000:
-        secondsSinceLastStatusPublication = long(60*60)
-        logging.debug("Capping secondsSinceLastStatusPublication to 1 hour")    
-    # otherwise, roll with it
-    else:
-        secondsSinceLastStatusPublication = \
-                (statusPublicationMillis - bhe.lastSeenWithThisAddressAndPort)/1000
-    if secondsSinceLastStatusPublication <= 0 and bhe.weightedTime > 0:
-        # old descriptor, bail
-        logging.warn("Received old descriptor for bridge %s with timestamp %d",
-                bhe.fingerprint, statusPublicationMillis/1000)
-        return bhe
+    with bridgedb.Storage.getDB() as db:
+        bhe = db.getBridgeHistory(bridge.fingerprint)
+        if not bhe:
+            # This is the first status, assume 60 minutes.
+            secondsSinceLastStatusPublication = long(60*60)
+            lastSeenWithDifferentAddressAndPort = timestamp * long(1000)
+            lastSeenWithThisAddressAndPort = timestamp * long(1000)
     
-    # iterate over all known bridges and apply weighting factor
-    discountAndPruneBridgeHistories(statusPublicationMillis)
+            bhe = BridgeHistory(
+                    bridge.fingerprint, bridge.ip, bridge.orport,
+                    0,#weightedUptime
+                    0,#weightedTime
+                    0,#weightedRunLength
+                    0,# totalRunWeights
+                    lastSeenWithDifferentAddressAndPort, # first timestamnp
+                    lastSeenWithThisAddressAndPort,
+                    0,#lastDiscountedHistoryValues,
+                    0,#lastUpdatedWeightedTime
+                    )
+            # first time we have seen this descriptor
+            db.updateIntoBridgeHistory(bhe)
+        # Calculate the seconds since the last parsed status.  If this is
+        # the first status or we haven't seen a status for more than 60
+        # minutes, assume 60 minutes.
+        statusPublicationMillis = long(timestamp * 1000)
+        if (statusPublicationMillis - bhe.lastSeenWithThisAddressAndPort) > 60*60*1000:
+            secondsSinceLastStatusPublication = long(60*60)
+            logging.debug("Capping secondsSinceLastStatusPublication to 1 hour")
+        # otherwise, roll with it
+        else:
+            secondsSinceLastStatusPublication = \
+                    (statusPublicationMillis - bhe.lastSeenWithThisAddressAndPort)/1000
+        if secondsSinceLastStatusPublication <= 0 and bhe.weightedTime > 0:
+            # old descriptor, bail
+            logging.warn("Received old descriptor for bridge %s with timestamp %d",
+                    bhe.fingerprint, statusPublicationMillis/1000)
+            return bhe
     
-    # Update the weighted times of bridges
-    updateWeightedTime(statusPublicationMillis)
-
-    # For Running Bridges only:
-    # compare the stored history against the descriptor and see if the
-    # bridge has changed its address or port
-    bhe = db.getBridgeHistory(bridge.fingerprint)
-
-    if not bridge.running:
-        logging.info("%s is not running" % bridge.fingerprint)
-        return bhe
-
-    # Parse the descriptor and see if the address or port changed
-    # If so, store the weighted run time
-    if bridge.orport != bhe.port or bridge.ip != bhe.ip:
-        bhe.totalRunWeights += 1.0;
-        bhe.weightedRunLength += bhe.tosa
-        bhe.lastSeenWithDifferentAddressAndPort =\
-                bhe.lastSeenWithThisAddressAndPort
-
-    # Regardless of whether the bridge is new, kept or changed
-    # its address and port, raise its WFU times and note its
-    # current address and port, and that we saw it using them.
-    bhe.weightedUptime += secondsSinceLastStatusPublication
-    bhe.lastSeenWithThisAddressAndPort = statusPublicationMillis
-    bhe.ip = str(bridge.ip)
-    bhe.port = bridge.orport
-    return db.updateIntoBridgeHistory(bhe)
+        # iterate over all known bridges and apply weighting factor
+        discountAndPruneBridgeHistories(statusPublicationMillis)
+
+        # Update the weighted times of bridges
+        updateWeightedTime(statusPublicationMillis)
+
+        # For Running Bridges only:
+        # compare the stored history against the descriptor and see if the
+        # bridge has changed its address or port
+        bhe = db.getBridgeHistory(bridge.fingerprint)
+
+        if not bridge.running:
+            logging.info("%s is not running" % bridge.fingerprint)
+            return bhe
+
+        # Parse the descriptor and see if the address or port changed
+        # If so, store the weighted run time
+        if bridge.orport != bhe.port or bridge.ip != bhe.ip:
+            bhe.totalRunWeights += 1.0;
+            bhe.weightedRunLength += bhe.tosa
+            bhe.lastSeenWithDifferentAddressAndPort =\
+                    bhe.lastSeenWithThisAddressAndPort
+
+        # Regardless of whether the bridge is new, kept or changed
+        # its address and port, raise its WFU times and note its
+        # current address and port, and that we saw it using them.
+        bhe.weightedUptime += secondsSinceLastStatusPublication
+        bhe.lastSeenWithThisAddressAndPort = statusPublicationMillis
+        bhe.ip = str(bridge.ip)
+        bhe.port = bridge.orport
+        return db.updateIntoBridgeHistory(bhe)
 
 def discountAndPruneBridgeHistories(discountUntilMillis):
-    db = bridgedb.Storage.getDB()
-    bhToRemove = []
-    bhToUpdate = []
-
-    for bh in db.getAllBridgeHistory():
-        # discount previous values by factor of 0.95 every 12 hours
-        bh.discountWeightedFractionalUptimeAndWeightedTime(discountUntilMillis)
-        # give the thing at least 24 hours before pruning it
-        if bh.weightedFractionalUptime < 1 and bh.weightedTime > 60*60*24:
-            logging.debug("Removing bridge from history: %s" % bh.fingerprint)
-            bhToRemove.append(bh.fingerprint)
-        else:
-            bhToUpdate.append(bh)
-
-    for k in bhToUpdate: db.updateIntoBridgeHistory(k)
-    for k in bhToRemove: db.delBridgeHistory(k)
+    with bridgedb.Storage.getDB() as db:
+        bhToRemove = []
+        bhToUpdate = []
+
+        for bh in db.getAllBridgeHistory():
+            # discount previous values by factor of 0.95 every 12 hours
+            bh.discountWeightedFractionalUptimeAndWeightedTime(discountUntilMillis)
+            # give the thing at least 24 hours before pruning it
+            if bh.weightedFractionalUptime < 1 and bh.weightedTime > 60*60*24:
+                logging.debug("Removing bridge from history: %s" % bh.fingerprint)
+                bhToRemove.append(bh.fingerprint)
+            else:
+                bhToUpdate.append(bh)
+
+        for k in bhToUpdate: db.updateIntoBridgeHistory(k)
+        for k in bhToRemove: db.delBridgeHistory(k)
 
 def updateWeightedTime(statusPublicationMillis):
     bhToUpdate = []
-    db = bridgedb.Storage.getDB()
-    for bh in db.getBridgesLastUpdatedBefore(statusPublicationMillis):
-        interval = (statusPublicationMillis - bh.lastUpdatedWeightedTime)/1000
-        if interval > 0:
-            bh.weightedTime += min(3600,interval) # cap to 1hr
-            bh.lastUpdatedWeightedTime = statusPublicationMillis
-            #db.updateIntoBridgeHistory(bh)
-            bhToUpdate.append(bh)
-
-    for bh in bhToUpdate:
-        db.updateIntoBridgeHistory(bh)
+    with bridgedb.Storage.getDB() as db:
+        for bh in db.getBridgesLastUpdatedBefore(statusPublicationMillis):
+            interval = (statusPublicationMillis - bh.lastUpdatedWeightedTime)/1000
+            if interval > 0:
+                bh.weightedTime += min(3600,interval) # cap to 1hr
+                bh.lastUpdatedWeightedTime = statusPublicationMillis
+                #db.updateIntoBridgeHistory(bh)
+                bhToUpdate.append(bh)
+
+        for bh in bhToUpdate:
+            db.updateIntoBridgeHistory(bh)
diff --git a/lib/bridgedb/Tests.py b/lib/bridgedb/Tests.py
index 378f1ae..72dfe5e 100644
--- a/lib/bridgedb/Tests.py
+++ b/lib/bridgedb/Tests.py
@@ -206,7 +206,7 @@ class EmailBridgeDistTests(unittest.TestCase):
     def setUp(self):
         self.fd, self.fname = tempfile.mkstemp()
         self.db = bridgedb.Storage.Database(self.fname)
-        bridgedb.Storage.setGlobalDB(self.db)
+        bridgedb.Storage.setDB(self.db)
         self.cur = self.db._conn.cursor()
 
     def tearDown(self):
@@ -672,7 +672,7 @@ class BridgeStabilityTests(unittest.TestCase):
     def setUp(self):
         self.fd, self.fname = tempfile.mkstemp()
         self.db = bridgedb.Storage.Database(self.fname)
-        bridgedb.Storage.setGlobalDB(self.db)
+        bridgedb.Storage.setDB(self.db)
         self.cur = self.db._conn.cursor()
 
     def tearDown(self):




