commit e027db3d3348944c7f75cda52c41b020f02b2aba
Author: Ximin Luo <infinity0@gmx.com>
Date:   Fri Oct 11 15:21:14 2013 +0100
simplify Endpoints a bit - remove Endpoints.supports() and related code - the important thing actually is whether a proxy supports a transport, not whether we have relays that support it - use defaultdict to get rid of some boilerplate, and populate _indexes unconditionally --- facilitator/facilitator | 89 +++++++++++++++++------------------------- facilitator/facilitator-test | 61 ++++++++--------------------- 2 files changed, 53 insertions(+), 97 deletions(-)
diff --git a/facilitator/facilitator b/facilitator/facilitator index d011013..9e11826 100755 --- a/facilitator/facilitator +++ b/facilitator/facilitator @@ -8,6 +8,7 @@ import sys import threading import time import traceback +from collections import defaultdict
import fac from fac import Transport, Endpoint @@ -26,6 +27,7 @@ CLIENT_TIMEOUT = 1.0 READLINE_MAX_LENGTH = 10240
MAX_PROXIES_PER_CLIENT = 5 +DEFAULT_OUTER_TRANSPORTS = ["websocket"]
LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
@@ -45,6 +47,7 @@ class options(object): pid_filename = None privdrop_username = None safe_logging = True + outer_transports = DEFAULT_OUTER_TRANSPORTS
def usage(f = sys.stdout): print >> f, """\ @@ -59,11 +62,15 @@ again. Listen on 127.0.0.1 and port PORT (by default %(port)d). --pidfile FILENAME write PID to FILENAME after daemonizing. --privdrop-user USER switch UID and GID to those of USER. -r, --relay-file RELAY learn relays from FILE. + --outer-transports TRANSPORTS + comma-sep list of outer transports to accept proxies + for (by default %(outer-transports)s) --unsafe-logging don't scrub IP addresses from logs.\ """ % { "progname": sys.argv[0], "port": DEFAULT_LISTEN_PORT, "log": DEFAULT_LOG_FILENAME, + "outer-transports": ",".join(DEFAULT_OUTER_TRANSPORTS) }
def safe_str(s): @@ -87,16 +94,13 @@ class Endpoints(object):
matchingLock = threading.Condition()
- def __init__(self, af, maxserve=float("inf"), known_outer=("websocket",)): + def __init__(self, af, maxserve=float("inf")): self.af = af self._maxserve = maxserve self._endpoints = {} # address -> transport - self._indexes = {} # outer -> [ addresses ] + self._indexes = defaultdict(lambda: defaultdict(set)) # outer -> inner -> [ addresses ] self._served = {} # address -> num_times_served self._cv = threading.Condition() - self.known_outer = set(known_outer) - for outer in self.known_outer: - self._ensureIndexForOuter(outer)
def getNumEndpoints(self): """:returns: the number of endpoints known to us.""" @@ -119,9 +123,10 @@ class Endpoints(object): transport = Transport.parse(transport) with self._cv: if addr in self._endpoints: return False + inner, outer = transport self._endpoints[addr] = transport self._served[addr] = 0 - self._addAddrIntoIndexes(addr) + self._indexes[outer][inner].add(addr) self._cv.notify() return True
@@ -133,43 +138,26 @@ class Endpoints(object): """ with self._cv: if addr not in self._endpoints: return False - self._delAddrFromIndexes(addr) + inner, outer = self._endpoints[addr] + self._indexes[outer][inner].remove(addr) # TODO(infinity0): maybe delete empty bins del self._served[addr] del self._endpoints[addr] self._cv.notify() return True
- def supports(self, transport): - """ - Estimate whether we support the given transport. May give false - positives, but doing a proper match later on will catch these. - - :returns: True if we know, or have met, proxies that might be able - to satisfy the requested transport against our known endpoints. - """ - transport = Transport.parse(transport) - with self._cv: - known_inner = self._findInnerForOuter(*self.known_outer).keys() - inner, outer = transport.inner, transport.outer - return inner in known_inner and outer in self.known_outer - def _findInnerForOuter(self, *supported_outer): """ :returns: { inner: [addr] }, where each address supports some outer from supported_outer. TODO(infinity0): describe better """ - self.known_outer.update(supported_outer) - inners = {} - for outer in supported_outer: - self._ensureIndexForOuter(outer) - for addr in self._indexes[outer]: - inner = self._endpoints[addr].inner - inners.setdefault(inner, set()).add(addr) + inners = defaultdict(set) + for outer in set(supported_outer) & set(self._indexes.iterkeys()): + for inner, addrs in self._indexes[outer].iteritems(): + if addrs: + # don't add empty bins, to avoid false-positive key checks + inners[inner].update(addrs) return inners
- def _avServed(self, addrpool): - return sum(self._served[a] for a in addrpool) / float(len(addrpool)) - def _serveReg(self, addrpool): """ :param list addrpool: List of candidate addresses. @@ -186,20 +174,6 @@ class Endpoints(object): self.delEndpoint(prio_addr) return prio_addr
- def _ensureIndexForOuter(self, outer): - if outer in self._indexes: return - addrs = set(addr for addr, transport in self._endpoints.iteritems() - if transport.outer == outer) - self._indexes[outer] = addrs - - def _addAddrIntoIndexes(self, addr): - outer = self._endpoints[addr].outer - if outer in self._indexes: self._indexes[outer].add(addr) - - def _delAddrFromIndexes(self, addr): - outer = self._endpoints[addr].outer - if outer in self._indexes: self._indexes[outer].remove(addr) - EMPTY_MATCH = (None, None) @staticmethod def match(ptsClient, ptsServer, supported_outer): @@ -208,7 +182,9 @@ class Endpoints(object): the available endpoints that can satisfy supported_outer. """ if ptsClient.af != ptsServer.af: - raise ValueError("address family not equal!") + raise ValueError("address family not equal") + if ptsServer._maxserve < float("inf"): + raise ValueError("servers mustn't run out") # need to operate on both structures # so hold both locks plus a pair-wise lock with Endpoints.matchingLock, ptsClient._cv, ptsServer._cv: @@ -216,16 +192,19 @@ class Endpoints(object): client_inner = ptsClient._findInnerForOuter(*supported_outer) both = set(server_inner.keys()) & set(client_inner.keys()) if not both: return Endpoints.EMPTY_MATCH - # pick the inner whose client address pool is least well-served - # TODO: this may be manipulated by clients, needs research - assert all(client_inner.itervalues()) # no pool is empty - inner = min(both, key=lambda p: ptsClient._avServed(client_inner[p])) - client_addr = ptsClient._serveReg(client_inner[inner]) + # find a client to serve + client_pool = [addr for inner in both for addr in client_inner[inner]] + assert len(client_pool) + client_addr = ptsClient._serveReg(client_pool) + # find a server to serve that has the same inner transport + inner = ptsClient._endpoints[client_addr].inner + assert inner in server_inner and len(server_inner[inner]) server_addr = ptsServer._serveReg(server_inner[inner]) # assume servers never run 
out client_transport = ptsClient._endpoints[client_addr] server_transport = ptsServer._endpoints[server_addr] - return Endpoint(client_addr, client_transport), Endpoint(server_addr, server_transport) + return (Endpoint(client_addr, client_transport), + Endpoint(server_addr, server_transport))
class Handler(SocketServer.StreamRequestHandler): @@ -354,7 +333,7 @@ class Handler(SocketServer.StreamRequestHandler):
transport = Transport.parse(transport) # See if we have relays that support this transport - if all(not pts.supports(transport) for pts in SERVERS.itervalues()): + if transport.outer not in options.outer_transports: return self.error(u"Unrecognized transport: %s" % transport)
client_spec = fac.param_first("CLIENT", params) @@ -445,6 +424,8 @@ def parse_relay_file(servers, fp): raise ValueError("Wrong line format: %s." % repr(line)) addr = fac.parse_addr_spec(addr_spec, defport=DEFAULT_RELAY_PORT, resolve=True) transport = Transport.parse(transport_spec) + if transport.outer not in options.outer_transports: + raise ValueError(u"Unrecognized transport: %s" % transport) af = addr_af(addr[0]) servers[af].addEndpoint(addr, transport)
@@ -468,6 +449,8 @@ def main(): options.privdrop_username = a elif o == "-r" or o == "--relay-file": options.relay_filename = a + elif o == "--outer-transports": + options.outer_transports = a.split(",") elif o == "--unsafe-logging": options.safe_logging = False
diff --git a/facilitator/facilitator-test b/facilitator/facilitator-test index 8e06053..8143348 100755 --- a/facilitator/facilitator-test +++ b/facilitator/facilitator-test @@ -29,56 +29,29 @@ class EndpointsTest(unittest.TestCase): def setUp(self): self.pts = Endpoints(af=socket.AF_INET)
- def _observeProxySupporting(self, *supported_outer): - # semantically observe the existence of a proxy, to make our intent - # a bit clearer than simply calling _findInnerForOuter - self.pts._findInnerForOuter(*supported_outer) - def test_addEndpoints_twice(self): self.pts.addEndpoint("A", "a|b|p") self.assertFalse(self.pts.addEndpoint("A", "zzz")) self.assertEquals(self.pts._endpoints["A"], Transport("a|b", "p"))
- def test_addEndpoints_lazy_indexing(self): + def test_delEndpoints_twice(self): + self.pts.addEndpoint("A", "a|b|p") + self.assertTrue(self.pts.delEndpoint("A")) + self.assertFalse(self.pts.delEndpoint("A")) + self.assertEquals(self.pts._endpoints.get("A"), None) + + def test_Endpoints_indexing(self): + self.assertEquals(self.pts._indexes.get("p"), None) + # test defaultdict works as expected + self.assertEquals(self.pts._indexes["p"]["a|b"], set("")) self.pts.addEndpoint("A", "a|b|p") - default_index = {"websocket": set()} # we always index known_outer - - # no index until we've asked for it - self.assertEquals(self.pts._indexes, default_index) - self._observeProxySupporting("p") - self.assertEquals(self.pts._indexes["p"], set("A")) - - # indexes are updated correctly after observing new addresses - self.pts.addEndpoint("B", "c|p") - self.assertEquals(self.pts._indexes["p"], set("AB")) - - # indexes are updated correctly after observing new proxies - self.pts.addEndpoint("C", "a|q") - self._observeProxySupporting("q") - self.assertEquals(self.pts._indexes["q"], set("C")) - - def test_supports_default(self): - # we know there are websocket-capable proxies out there; - # support them implicitly without needing to see a proxy. - self.pts.addEndpoint("A", "obfs3|websocket") - self.assertTrue(self.pts.supports("obfs3|websocket")) - self.assertFalse(self.pts.supports("xxx|websocket")) - self.assertFalse(self.pts.supports("websocket")) - self.assertFalse(self.pts.supports("unknownwhat")) - # doesn't matter what the first part is - self.pts.addEndpoint("B", "xxx|websocket") - self.assertTrue(self.pts.supports("xxx|websocket")) - - def test_supports_seen_proxy(self): - # OTOH if some 3rd-party proxy decides to implement its own transport - # we are fully capable of supporting them too, but only if we have - # an endpoint that also speaks it. 
- self.assertFalse(self.pts.supports("obfs3|unknownwhat")) - self._observeProxySupporting("unknownwhat") - self.assertFalse(self.pts.supports("obfs3|unknownwhat")) - self.pts.addEndpoint("A", "obfs3|unknownwhat") - self.assertTrue(self.pts.supports("obfs3|unknownwhat")) - self.assertFalse(self.pts.supports("obfs2|unknownwhat")) + self.assertEquals(self.pts._indexes["p"]["a|b"], set("A")) + self.pts.addEndpoint("B", "a|b|p") + self.assertEquals(self.pts._indexes["p"]["a|b"], set("AB")) + self.pts.delEndpoint("A") + self.assertEquals(self.pts._indexes["p"]["a|b"], set("B")) + self.pts.delEndpoint("B") + self.assertEquals(self.pts._indexes["p"]["a|b"], set(""))
def test_serveReg_maxserve_infinite_roundrobin(self): # case for servers, they never exhaust