[flashproxy/master] - fix match() trying to access info of deleted address

commit 9b155c9b2029ec779666a661e5eecf8226cef6c8
Author: Ximin Luo <infinity0@gmx.com>
Date:   Fri Oct 11 17:12:07 2013 +0100

    - fix match() trying to access info of deleted address
---
 facilitator/facilitator      | 16 +++++++---------
 facilitator/facilitator-test | 38 ++++++++++++++++++++++++++------------
 2 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/facilitator/facilitator b/facilitator/facilitator
index 9e11826..844836a 100755
--- a/facilitator/facilitator
+++ b/facilitator/facilitator
@@ -161,7 +161,7 @@ class Endpoints(object):
     def _serveReg(self, addrpool):
         """
         :param list addrpool: List of candidate addresses.
-        :returns: An address of an endpoint from the given pool. The serve
+        :returns: An Endpoint whose address is from the given pool. The serve
             counter for that address is also incremented, and if it hits
             self._maxserve the endpoint is removed from this collection.
         :raises: KeyError if any address is not registered with this collection
@@ -170,9 +170,10 @@ class Endpoints(object):
         prio_addr = min(addrpool, key=lambda a: self._served[a])
         assert self._served[prio_addr] < self._maxserve
         self._served[prio_addr] += 1
+        transport = self._endpoints[prio_addr]
         if self._served[prio_addr] == self._maxserve:
             self.delEndpoint(prio_addr)
-        return prio_addr
+        return Endpoint(prio_addr, transport)
 
     EMPTY_MATCH = (None, None)
     @staticmethod
@@ -195,16 +196,13 @@ class Endpoints(object):
         # find a client to serve
         client_pool = [addr for inner in both for addr in client_inner[inner]]
         assert len(client_pool)
-        client_addr = ptsClient._serveReg(client_pool)
+        client_reg = ptsClient._serveReg(client_pool)
         # find a server to serve that has the same inner transport
-        inner = ptsClient._endpoints[client_addr].inner
+        inner = client_reg.transport.inner
         assert inner in server_inner and len(server_inner[inner])
-        server_addr = ptsServer._serveReg(server_inner[inner])
+        server_reg = ptsServer._serveReg(server_inner[inner])
         # assume servers never run out
-        client_transport = ptsClient._endpoints[client_addr]
-        server_transport = ptsServer._endpoints[server_addr]
-        return (Endpoint(client_addr, client_transport),
-                Endpoint(server_addr, server_transport))
+        return (client_reg, server_reg)
 
 
 class Handler(SocketServer.StreamRequestHandler):
diff --git a/facilitator/facilitator-test b/facilitator/facilitator-test
index 8143348..3efe34d 100755
--- a/facilitator/facilitator-test
+++ b/facilitator/facilitator-test
@@ -60,9 +60,9 @@ class EndpointsTest(unittest.TestCase):
         self.pts.addEndpoint("C", "a|p")
         for i in xrange(64): # 64 is infinite ;)
             served = set()
-            served.add(self.pts._serveReg("ABC"))
-            served.add(self.pts._serveReg("ABC"))
-            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC").addr)
+            served.add(self.pts._serveReg("ABC").addr)
+            served.add(self.pts._serveReg("ABC").addr)
             self.assertEquals(served, set("ABC"))
 
     def test_serveReg_maxserve_finite_exhaustion(self):
@@ -74,25 +74,25 @@ class EndpointsTest(unittest.TestCase):
         # test getNumUnservedEndpoints whilst we're at it
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 3)
         served = set()
-        served.add(self.pts._serveReg("ABC"))
+        served.add(self.pts._serveReg("ABC").addr)
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 2)
-        served.add(self.pts._serveReg("ABC"))
+        served.add(self.pts._serveReg("ABC").addr)
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 1)
-        served.add(self.pts._serveReg("ABC"))
+        served.add(self.pts._serveReg("ABC").addr)
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
         self.assertEquals(served, set("ABC"))
         for i in xrange(5-2):
             served = set()
-            served.add(self.pts._serveReg("ABC"))
-            served.add(self.pts._serveReg("ABC"))
-            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC").addr)
+            served.add(self.pts._serveReg("ABC").addr)
+            served.add(self.pts._serveReg("ABC").addr)
             self.assertEquals(served, set("ABC"))
         remaining = set("ABC")
-        remaining.remove(self.pts._serveReg(remaining))
+        remaining.remove(self.pts._serveReg(remaining).addr)
         self.assertRaises(KeyError, self.pts._serveReg, "ABC")
-        remaining.remove(self.pts._serveReg(remaining))
+        remaining.remove(self.pts._serveReg(remaining).addr)
         self.assertRaises(KeyError, self.pts._serveReg, "ABC")
-        remaining.remove(self.pts._serveReg(remaining))
+        remaining.remove(self.pts._serveReg(remaining).addr)
         self.assertRaises(KeyError, self.pts._serveReg, "ABC")
         self.assertEquals(remaining, set())
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
@@ -151,6 +151,20 @@ class EndpointsTest(unittest.TestCase):
         self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
         self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
 
+    def test_match_exhaustion(self):
+        self.pts.addEndpoint("A", "p")
+        self.pts2 = Endpoints(af=socket.AF_INET, maxserve=2)
+        self.pts2.addEndpoint("B", "p")
+        print self.pts2._indexes, self.pts2._served
+        Endpoints.match(self.pts2, self.pts, ["p"])
+        print self.pts2._indexes, self.pts2._served
+        Endpoints.match(self.pts2, self.pts, ["p"])
+        empty = Endpoints.EMPTY_MATCH
+        self.assertTrue("B" not in self.pts2._endpoints)
+        self.assertTrue("B" not in self.pts2._indexes["p"][""])
+        self.assertEquals(empty, Endpoints.match(self.pts2, self.pts, ["p"]))
+
+
 class FacilitatorTest(unittest.TestCase):
 
     def test_transport_parse(self):
participants (1)
- infinity0@torproject.org