[tor-commits] [flashproxy/master] drop registrations that hit _maxserve, and fix tests to run properly

infinity0 at torproject.org infinity0 at torproject.org
Mon Oct 28 14:47:41 UTC 2013


commit d658db1f29d698a7820ef0d4bdf1120e01fb35af
Author: Ximin Luo <infinity0 at gmx.com>
Date:   Wed Oct 9 23:49:38 2013 +0100

    drop registrations that hit _maxserve, and fix tests to run properly
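    - the old behaviour (returning None once every endpoint in the pool had been served _maxserve times, without ever removing exhausted endpoints) was based on a misunderstanding of the previous version of the code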
---
 facilitator/facilitator      |   20 ++++++++++----------
 facilitator/facilitator-test |   23 +++++++++++++++++++----
 2 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/facilitator/facilitator b/facilitator/facilitator
index f3ae79b..07038d1 100755
--- a/facilitator/facilitator
+++ b/facilitator/facilitator
@@ -173,17 +173,18 @@ class Endpoints(object):
     def _serveReg(self, addrpool):
         """
         :param list addrpool: List of candidate addresses.
-        :returns: An address of an endpoint from the given pool, or None if all
-            endpoints have already been served _maxserve times. The serve
-            counter for that address is also incremented.
+        :returns: An address of an endpoint from the given pool. The serve
+            counter for that address is also incremented, and if it hits
+            self._maxserve the endpoint is removed from this collection.
+        :raises: KeyError if any address is not registered with this collection
         """
-        if not addrpool: return None
+        if not addrpool: raise ValueError("gave empty address pool")
         prio_addr = min(addrpool, key=lambda a: self._served[a])
-        if self._served[prio_addr] < self._maxserve:
-            self._served[prio_addr] += 1
-            return prio_addr
-        else:
-            return None
+        assert self._served[prio_addr] < self._maxserve
+        self._served[prio_addr] += 1
+        if self._served[prio_addr] == self._maxserve:
+            self.delEndpoint(prio_addr)
+        return prio_addr
 
     def _ensureIndexForSuffix(self, suf):
         if suf in self._indexes: return
@@ -227,7 +228,6 @@ class Endpoints(object):
             assert all(client_pre.itervalues()) # no pool is empty
             pre = min(both, key=lambda p: ptsClient._avServed(client_pre[p]))
             client_addr = ptsClient._serveReg(client_pre[pre])
-            if not client_addr: return Endpoints.EMPTY_MATCH
             server_addr = ptsServer._serveReg(server_pre[pre])
             # assume servers never run out
             client_transport = ptsClient._endpoints[client_addr]
diff --git a/facilitator/facilitator-test b/facilitator/facilitator-test
index 6709221..3f2fbef 100755
--- a/facilitator/facilitator-test
+++ b/facilitator/facilitator-test
@@ -80,7 +80,7 @@ class EndpointsTest(unittest.TestCase):
         self.assertTrue(self.pts.supports("obfs3|unknownwhat"))
         self.assertFalse(self.pts.supports("obfs2|unknownwhat"))
 
-    def _test_serveReg_maxserve_infinite_roundrobin(self):
+    def test_serveReg_maxserve_infinite_roundrobin(self):
         # case for servers, they never exhaust
         self.pts.addEndpoint("A", "a|p")
         self.pts.addEndpoint("B", "a|p")
@@ -92,7 +92,7 @@ class EndpointsTest(unittest.TestCase):
             served.add(self.pts._serveReg("ABC"))
             self.assertEquals(served, set("ABC"))
 
-    def _test_serveReg_maxserve_finite_exhaustion(self):
+    def test_serveReg_maxserve_finite_exhaustion(self):
         # case for clients, we don't want to keep serving them
         self.pts = Endpoints(af=socket.AF_INET, maxserve=5)
         self.pts.addEndpoint("A", "a|p")
@@ -100,13 +100,28 @@ class EndpointsTest(unittest.TestCase):
         self.pts.addEndpoint("C", "a|p")
         # test getNumUnservedEndpoints whilst we're at it
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 3)
-        for i in xrange(5):
+        served = set()
+        served.add(self.pts._serveReg("ABC"))
+        self.assertEquals(self.pts.getNumUnservedEndpoints(), 2)
+        served.add(self.pts._serveReg("ABC"))
+        self.assertEquals(self.pts.getNumUnservedEndpoints(), 1)
+        served.add(self.pts._serveReg("ABC"))
+        self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
+        self.assertEquals(served, set("ABC"))
+        for i in xrange(5-2):
             served = set()
             served.add(self.pts._serveReg("ABC"))
             served.add(self.pts._serveReg("ABC"))
             served.add(self.pts._serveReg("ABC"))
             self.assertEquals(served, set("ABC"))
-        self.assertEquals(None, self.pts._serveReg("ABC"))
+        remaining = set("ABC")
+        remaining.remove(self.pts._serveReg(remaining))
+        self.assertRaises(KeyError, self.pts._serveReg, "ABC")
+        remaining.remove(self.pts._serveReg(remaining))
+        self.assertRaises(KeyError, self.pts._serveReg, "ABC")
+        remaining.remove(self.pts._serveReg(remaining))
+        self.assertRaises(KeyError, self.pts._serveReg, "ABC")
+        self.assertEquals(remaining, set())
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
 
     def test_match_normal(self):





More information about the tor-commits mailing list