[tor-commits] [flashproxy/master] - fix match() trying to access info of deleted address

infinity0 at torproject.org infinity0 at torproject.org
Mon Oct 28 14:47:41 UTC 2013


commit 9b155c9b2029ec779666a661e5eecf8226cef6c8
Author: Ximin Luo <infinity0 at gmx.com>
Date:   Fri Oct 11 17:12:07 2013 +0100

    - fix match() trying to access info of deleted address
---
 facilitator/facilitator      |   16 +++++++---------
 facilitator/facilitator-test |   38 ++++++++++++++++++++++++++------------
 2 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/facilitator/facilitator b/facilitator/facilitator
index 9e11826..844836a 100755
--- a/facilitator/facilitator
+++ b/facilitator/facilitator
@@ -161,7 +161,7 @@ class Endpoints(object):
     def _serveReg(self, addrpool):
         """
         :param list addrpool: List of candidate addresses.
-        :returns: An address of an endpoint from the given pool. The serve
+        :returns: An Endpoint whose address is from the given pool. The serve
             counter for that address is also incremented, and if it hits
             self._maxserve the endpoint is removed from this collection.
         :raises: KeyError if any address is not registered with this collection
@@ -170,9 +170,10 @@ class Endpoints(object):
         prio_addr = min(addrpool, key=lambda a: self._served[a])
         assert self._served[prio_addr] < self._maxserve
         self._served[prio_addr] += 1
+        transport = self._endpoints[prio_addr]
         if self._served[prio_addr] == self._maxserve:
             self.delEndpoint(prio_addr)
-        return prio_addr
+        return Endpoint(prio_addr, transport)
 
     EMPTY_MATCH = (None, None)
     @staticmethod
@@ -195,16 +196,13 @@ class Endpoints(object):
             # find a client to serve
             client_pool = [addr for inner in both for addr in client_inner[inner]]
             assert len(client_pool)
-            client_addr = ptsClient._serveReg(client_pool)
+            client_reg = ptsClient._serveReg(client_pool)
             # find a server to serve that has the same inner transport
-            inner = ptsClient._endpoints[client_addr].inner
+            inner = client_reg.transport.inner
             assert inner in server_inner and len(server_inner[inner])
-            server_addr = ptsServer._serveReg(server_inner[inner])
+            server_reg = ptsServer._serveReg(server_inner[inner])
             # assume servers never run out
-            client_transport = ptsClient._endpoints[client_addr]
-            server_transport = ptsServer._endpoints[server_addr]
-            return (Endpoint(client_addr, client_transport),
-                    Endpoint(server_addr, server_transport))
+            return (client_reg, server_reg)
 
 
 class Handler(SocketServer.StreamRequestHandler):
diff --git a/facilitator/facilitator-test b/facilitator/facilitator-test
index 8143348..3efe34d 100755
--- a/facilitator/facilitator-test
+++ b/facilitator/facilitator-test
@@ -60,9 +60,9 @@ class EndpointsTest(unittest.TestCase):
         self.pts.addEndpoint("C", "a|p")
         for i in xrange(64): # 64 is infinite ;)
             served = set()
-            served.add(self.pts._serveReg("ABC"))
-            served.add(self.pts._serveReg("ABC"))
-            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC").addr)
+            served.add(self.pts._serveReg("ABC").addr)
+            served.add(self.pts._serveReg("ABC").addr)
             self.assertEquals(served, set("ABC"))
 
     def test_serveReg_maxserve_finite_exhaustion(self):
@@ -74,25 +74,25 @@ class EndpointsTest(unittest.TestCase):
         # test getNumUnservedEndpoints whilst we're at it
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 3)
         served = set()
-        served.add(self.pts._serveReg("ABC"))
+        served.add(self.pts._serveReg("ABC").addr)
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 2)
-        served.add(self.pts._serveReg("ABC"))
+        served.add(self.pts._serveReg("ABC").addr)
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 1)
-        served.add(self.pts._serveReg("ABC"))
+        served.add(self.pts._serveReg("ABC").addr)
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
         self.assertEquals(served, set("ABC"))
         for i in xrange(5-2):
             served = set()
-            served.add(self.pts._serveReg("ABC"))
-            served.add(self.pts._serveReg("ABC"))
-            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC").addr)
+            served.add(self.pts._serveReg("ABC").addr)
+            served.add(self.pts._serveReg("ABC").addr)
             self.assertEquals(served, set("ABC"))
         remaining = set("ABC")
-        remaining.remove(self.pts._serveReg(remaining))
+        remaining.remove(self.pts._serveReg(remaining).addr)
         self.assertRaises(KeyError, self.pts._serveReg, "ABC")
-        remaining.remove(self.pts._serveReg(remaining))
+        remaining.remove(self.pts._serveReg(remaining).addr)
         self.assertRaises(KeyError, self.pts._serveReg, "ABC")
-        remaining.remove(self.pts._serveReg(remaining))
+        remaining.remove(self.pts._serveReg(remaining).addr)
         self.assertRaises(KeyError, self.pts._serveReg, "ABC")
         self.assertEquals(remaining, set())
         self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
@@ -151,6 +151,20 @@ class EndpointsTest(unittest.TestCase):
         self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
         self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
 
+    def test_match_exhaustion(self):  # regression: client deleted inside match()
+        self.pts.addEndpoint("A", "p")
+        self.pts2 = Endpoints(af=socket.AF_INET, maxserve=2)
+        self.pts2.addEndpoint("B", "p")
+        self.assertEquals(self.pts2._served["B"], 0)
+        Endpoints.match(self.pts2, self.pts, ["p"])
+        self.assertEquals(self.pts2._served["B"], 1)  # 1 < maxserve: B kept
+        Endpoints.match(self.pts2, self.pts, ["p"])  # 2nd serve deletes B
+        empty = Endpoints.EMPTY_MATCH
+        self.assertTrue("B" not in self.pts2._endpoints)
+        self.assertTrue("B" not in self.pts2._indexes["p"][""])
+        self.assertEquals(empty, Endpoints.match(self.pts2, self.pts, ["p"]))
+
+
 class FacilitatorTest(unittest.TestCase):
 
     def test_transport_parse(self):





More information about the tor-commits mailing list