[flashproxy/master] drop registrations that hit _maxserve, and fix tests to run properly

commit d658db1f29d698a7820ef0d4bdf1120e01fb35af Author: Ximin Luo <infinity0@gmx.com> Date: Wed Oct 9 23:49:38 2013 +0100 drop registrations that hit _maxserve, and fix tests to run properly - the old behaviour was based on an incorrect understanding of the previous iteration of the code --- facilitator/facilitator | 20 ++++++++++---------- facilitator/facilitator-test | 23 +++++++++++++++++++---- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/facilitator/facilitator b/facilitator/facilitator index f3ae79b..07038d1 100755 --- a/facilitator/facilitator +++ b/facilitator/facilitator @@ -173,17 +173,18 @@ class Endpoints(object): def _serveReg(self, addrpool): """ :param list addrpool: List of candidate addresses. - :returns: An address of an endpoint from the given pool, or None if all - endpoints have already been served _maxserve times. The serve - counter for that address is also incremented. + :returns: An address of an endpoint from the given pool. The serve + counter for that address is also incremented, and if it hits + self._maxserve the endpoint is removed from this collection. + :raises: KeyError if any address is not registered with this collection """ - if not addrpool: return None + if not addrpool: raise ValueError("gave empty address pool") prio_addr = min(addrpool, key=lambda a: self._served[a]) - if self._served[prio_addr] < self._maxserve: - self._served[prio_addr] += 1 - return prio_addr - else: - return None + assert self._served[prio_addr] < self._maxserve + self._served[prio_addr] += 1 + if self._served[prio_addr] == self._maxserve: + self.delEndpoint(prio_addr) + return prio_addr def _ensureIndexForSuffix(self, suf): if suf in self._indexes: return @@ -227,7 +228,6 @@ class Endpoints(object): assert all(client_pre.itervalues()) # no pool is empty pre = min(both, key=lambda p: ptsClient._avServed(client_pre[p])) client_addr = ptsClient._serveReg(client_pre[pre]) - if not client_addr: return Endpoints.EMPTY_MATCH server_addr = ptsServer._serveReg(server_pre[pre]) # assume servers never run out client_transport = ptsClient._endpoints[client_addr] diff --git a/facilitator/facilitator-test b/facilitator/facilitator-test index 6709221..3f2fbef 100755 --- a/facilitator/facilitator-test +++ b/facilitator/facilitator-test @@ -80,7 +80,7 @@ class EndpointsTest(unittest.TestCase): self.assertTrue(self.pts.supports("obfs3|unknownwhat")) self.assertFalse(self.pts.supports("obfs2|unknownwhat")) - def _test_serveReg_maxserve_infinite_roundrobin(self): + def test_serveReg_maxserve_infinite_roundrobin(self): # case for servers, they never exhaust self.pts.addEndpoint("A", "a|p") self.pts.addEndpoint("B", "a|p") @@ -92,7 +92,7 @@ class EndpointsTest(unittest.TestCase): served.add(self.pts._serveReg("ABC")) self.assertEquals(served, set("ABC")) - def _test_serveReg_maxserve_finite_exhaustion(self): + def test_serveReg_maxserve_finite_exhaustion(self): # case for clients, we don't want to keep serving them self.pts = Endpoints(af=socket.AF_INET, maxserve=5) self.pts.addEndpoint("A", "a|p") @@ -100,13 +100,28 @@ class EndpointsTest(unittest.TestCase): self.pts.addEndpoint("C", "a|p") # test getNumUnservedEndpoints whilst we're at it self.assertEquals(self.pts.getNumUnservedEndpoints(), 3) - for i in xrange(5): + served = set() + served.add(self.pts._serveReg("ABC")) + self.assertEquals(self.pts.getNumUnservedEndpoints(), 2) + served.add(self.pts._serveReg("ABC")) + self.assertEquals(self.pts.getNumUnservedEndpoints(), 1) + served.add(self.pts._serveReg("ABC")) + self.assertEquals(self.pts.getNumUnservedEndpoints(), 0) + self.assertEquals(served, set("ABC")) + for i in xrange(5-2): served = set() served.add(self.pts._serveReg("ABC")) served.add(self.pts._serveReg("ABC")) served.add(self.pts._serveReg("ABC")) self.assertEquals(served, set("ABC")) - self.assertEquals(None, self.pts._serveReg("ABC")) + remaining = set("ABC") + remaining.remove(self.pts._serveReg(remaining)) + self.assertRaises(KeyError, self.pts._serveReg, "ABC") + remaining.remove(self.pts._serveReg(remaining)) + self.assertRaises(KeyError, self.pts._serveReg, "ABC") + remaining.remove(self.pts._serveReg(remaining)) + self.assertRaises(KeyError, self.pts._serveReg, "ABC") + self.assertEquals(remaining, set()) self.assertEquals(self.pts.getNumUnservedEndpoints(), 0) def test_match_normal(self):
participants (1)
-
infinity0@torproject.org