commit 2658f6e3eac20a4f7c20e2c6b90ddf253731dce1
Author: Ximin Luo <infinity0@gmx.com>
Date:   Mon Oct 7 16:47:53 2013 +0100
    Reimplement matching algorithm using an Endpoints datastruct for both client/server

    - select prefix pool based on least-well-served rather than arbitrarily
    - don't attempt a match until a proxy is available to service the request
    - don't match ipv6 proxies to ipv4 servers
---
 facilitator/facilitator      | 428 +++++++++++++++++++++++-------------------
 facilitator/facilitator-test | 166 +++++++++++++++-
 2 files changed, 396 insertions(+), 198 deletions(-)
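The "least-well-served" rule in the summary means: among the prefix pools that both sides support, pick the one whose client addresses have the lowest average serve count. A minimal sketch of that policy, using hypothetical names (pools maps each prefix to its set of addresses, served counts hand-outs per address); the real logic lives in Endpoints._avServed and Endpoints.match below:

    # Sketch only; hypothetical helper names, not part of this commit.
    def pick_least_well_served(pools, served):
        def av_served(addrs):
            # average number of times this pool's addresses were handed out
            return sum(served[a] for a in addrs) / float(len(addrs))
        return min(pools, key=lambda prefix: av_served(pools[prefix]))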
diff --git a/facilitator/facilitator b/facilitator/facilitator
index e44047e..d4088ff 100755
--- a/facilitator/facilitator
+++ b/facilitator/facilitator
@@ -8,6 +8,7 @@ import sys
 import threading
 import time
 import traceback
+from collections import namedtuple
 import fac
@@ -65,26 +66,6 @@ again. Listen on 127.0.0.1 and port PORT (by default %(port)d).
     "log": DEFAULT_LOG_FILENAME,
 }
-def num_relays():
-    return sum(len(x) for x in RELAYS.values())
-
-def parse_transport_chain(spec):
-    """Parse a transport chain string and return a tuple of individual
-    transports, each of which is a string.
-    >>> parse_transport_chain("obfs3|websocket")
-    ('obfs3', 'websocket')
-    """
-    assert(spec)
-    return tuple(spec.split("|"))
-
-def get_outermost_transport(transports):
-    """Given a transport chain tuple, return the last element.
-    >>> get_outermost_transport(("obfs3", "websocket"))
-    'websocket'
-    """
-    assert(transports)
-    return transports[-1]
-
 def safe_str(s):
     """Return "[scrubbed]" if options.safe_logging is true, and s otherwise."""
     if options.safe_logging:
@@ -98,87 +79,191 @@ def log(msg):
     print >> options.log_file, (u"%s %s" % (time.strftime(LOG_DATE_FORMAT), msg)).encode("UTF-8")
     options.log_file.flush()
-class TCPReg(object):
-    def __init__(self, host, port, transports):
-        self.host = host
-        self.port = port
-        self.transports = transports
-        # Get a relay for this registration. Throw UnknownTransport if
-        # could not be found.
-        self.relay = self._get_matching_relay()
-    def _get_matching_relay(self):
-        """Return a matching relay address for this registration. Raise
-        UnknownTransport if a relay with a matching transport chain could not be
-        found."""
-        if self.transports not in RELAYS:
-            raise UnknownTransport("Can't find relay with transport chain: %s" % self.transports)
-
-        # Maybe this should be a random pick from the set of all the
-        # eligible relays. But let's keep it deterministic for now,
-        # and return the first one.
-
-        # return random.choice(RELAYS[self.transports])
-        return RELAYS[self.transports][0]
+class Transport(namedtuple("Transport", "prefix suffix")):
+    @classmethod
+    def parse(cls, transport):
+        if isinstance(transport, cls):
+            return transport
+        elif type(transport) == str:
+            if "|" in transport:
+                prefix, suffix = transport.rsplit("|", 1)
+            else:
+                prefix, suffix = "", transport
+            return cls(prefix, suffix)
+        else:
+            raise ValueError("could not parse transport: %s" % transport)
-    def __unicode__(self):
-        return fac.format_addr((self.host, self.port))
+    def __init__(self, prefix, suffix):
+        if not suffix:
+            raise ValueError("suffix (proxy) part of transport must be non-empty: %s" % str(self))
     def __str__(self):
-        return unicode(self).encode("UTF-8")
+        return "%s|%s" % (self.prefix, self.suffix) if self.prefix else self.suffix
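For reference, Transport.parse splits on the last "|": everything before it is the prefix chain, and the final element is the proxy-facing suffix. These values match the test_transport_parse cases later in this diff:

    Transport.parse("websocket")        # Transport(prefix='', suffix='websocket')
    Transport.parse("obfs3|websocket")  # Transport(prefix='obfs3', suffix='websocket')
    Transport.parse("a|b|c")            # Transport(prefix='a|b', suffix='c')
    Transport.parse("a|")               # raises ValueError: empty suffix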
-    def __cmp__(self, other):
-        if isinstance(other, TCPReg):
-            # XXX is this correct comparison?
-            return cmp((self.host, self.port, self.transports), (other.host, other.port, other.transports))
-        else:
-            return False
-class Reg(object):
-    @staticmethod
-    def parse(spec, transports, defhost = None, defport = None):
+class Reg(namedtuple("Reg", "addr transport")):
+    @classmethod
+    def parse(cls, spec, transport, defhost = None, defport = None):
         host, port = fac.parse_addr_spec(spec, defhost, defport)
-        return TCPReg(host, port, transports)
+        return cls((host, port), Transport.parse(transport))
-class RegSet(object):
-    def __init__(self):
-        self.tiers = [[] for i in range(MAX_PROXIES_PER_CLIENT)]
-        self.cv = threading.Condition()
-    def add(self, reg):
-        self.cv.acquire()
-        try:
-            for tier in self.tiers:
-                if reg in tier:
-                    break
-            else:
-                self.tiers[0].append(reg)
-                self.cv.notify()
-                return True
-            return False
-        finally:
-            self.cv.release()
-
-    def fetch(self):
-        self.cv.acquire()
-        try:
-            for i in range(len(self.tiers)):
-                tier = self.tiers[i]
-                if tier:
-                    reg = tier.pop(0)
-                    if i + 1 < len(self.tiers):
-                        self.tiers[i+1].append(reg)
-                    return reg
+class Endpoints(object):
+    """
+    Tracks endpoints (either client/server) and the transport chains that
+    they support.
+    """
+
+    matchingLock = threading.Condition()
+
+    def __init__(self, af, maxserve=float("inf"), known_suf=("websocket",)):
+        self.af = af
+        self._maxserve = maxserve
+        self._endpoints = {} # address -> transport
+        self._indexes = {} # suffix -> [ addresses ]
+        self._served = {} # address -> num_times_served
+        self._cv = threading.Condition()
+        self.known_suf = set(known_suf)
+        for suf in self.known_suf:
+            self._ensureIndexForSuffix(suf)
+
+    def getNumEndpoints(self):
+        """:returns: the number of endpoints known to us."""
+        with self._cv:
+            return len(self._endpoints)
+
+    def getNumUnservedEndpoints(self):
+        """:returns: the number of unserved endpoints known to us."""
+        with self._cv:
+            return len(filter(lambda t: t == 0, self._served.itervalues()))
+
+    def addEndpoint(self, addr, transport):
+        """Add an endpoint.
+
+        :param addr: Address of endpoint, usage-dependent.
+        :param transport: Transport for this endpoint; strings are parsed
+            via Transport.parse.
+        :returns: False if the address is already known, in which case no
+            update is made to its supported transports, else True.
+        """
+        transport = Transport.parse(transport)
+        with self._cv:
+            if addr in self._endpoints: return False
+            self._endpoints[addr] = transport
+            self._served[addr] = 0
+            self._addAddrIntoIndexes(addr)
+            self._cv.notify()
+            return True
+
+    def delEndpoint(self, addr):
+        """Forget an endpoint.
+
+        :param addr: Address of endpoint, usage-dependent.
+        :returns: False if the address was already forgotten, else True.
+        """
+        with self._cv:
+            if addr not in self._endpoints: return False
+            self._delAddrFromIndexes(addr)
+            del self._served[addr]
+            del self._endpoints[addr]
+            self._cv.notify()
+            return True
+
+    def supports(self, transport):
+        """
+        Estimate whether we support the given transport. May give false
+        positives, but doing a proper match later on will catch these.
+
+        :returns: True if we know, or have met, proxies that might be able
+            to satisfy the requested transport against our known endpoints.
+        """
+        transport = Transport.parse(transport)
+        with self._cv:
+            known_pre = self._findPrefixesForSuffixes(*self.known_suf).keys()
+            pre, suf = transport.prefix, transport.suffix
+            return pre in known_pre and suf in self.known_suf
+
+    def _findPrefixesForSuffixes(self, *supported_suf):
+        """
+        :returns: { prefix: [addr] }, where each address supports some suffix
+            from supported_suf. TODO(infinity0): describe better
+        """
+        self.known_suf.update(supported_suf)
+        prefixes = {}
+        for suf in supported_suf:
+            self._ensureIndexForSuffix(suf)
+            for addr in self._indexes[suf]:
+                pre = self._endpoints[addr].prefix
+                prefixes.setdefault(pre, set()).add(addr)
+        return prefixes
+
+    def _avServed(self, addrpool):
+        return sum(self._served[a] for a in addrpool) / float(len(addrpool))
+
+    def _serveReg(self, addrpool):
+        """
+        :param list addrpool: List of candidate addresses.
+        :returns: An address of an endpoint from the given pool, or None if all
+            endpoints have already been served _maxserve times. The serve
+            counter for that address is also incremented.
+        """
+        if not addrpool: return None
+        prio_addr = min(addrpool, key=lambda a: self._served[a])
+        if self._served[prio_addr] < self._maxserve:
+            self._served[prio_addr] += 1
+            return prio_addr
+        else: return None
-        finally:
-            self.cv.release()
-    def __len__(self):
-        self.cv.acquire()
-        try:
-            return sum(len(tier) for tier in self.tiers)
-        finally:
-            self.cv.release()
+    def _ensureIndexForSuffix(self, suf):
+        if suf in self._indexes: return
+        addrs = set(addr for addr, transport in self._endpoints.iteritems()
+                    if transport.suffix == suf)
+        self._indexes[suf] = addrs
+
+    def _addAddrIntoIndexes(self, addr):
+        suf = self._endpoints[addr].suffix
+        if suf in self._indexes: self._indexes[suf].add(addr)
+
+    def _delAddrFromIndexes(self, addr):
+        suf = self._endpoints[addr].suffix
+        if suf in self._indexes: self._indexes[suf].remove(addr)
+
+    def _prefixesForTransport(self, transport, *supported_suf):
+        for suf in supported_suf:
+            if not suf:
+                yield transport
+            elif transport[-len(suf):] == suf:
+                yield transport[:-len(suf)]
+
+    EMPTY_MATCH = (None, None)
+    @staticmethod
+    def match(ptsClient, ptsServer, supported_suf):
+        """
+        :returns: A tuple (client Reg, server Reg) arbitrarily selected from
+            the available endpoints that can satisfy supported_suf.
+        """
+        if ptsClient.af != ptsServer.af:
+            raise ValueError("address family not equal!")
+        # need to operate on both structures
+        # so hold both locks plus a pair-wise lock
+        with Endpoints.matchingLock, ptsClient._cv, ptsServer._cv:
+            server_pre = ptsServer._findPrefixesForSuffixes(*supported_suf)
+            client_pre = ptsClient._findPrefixesForSuffixes(*supported_suf)
+            both = set(server_pre.keys()) & set(client_pre.keys())
+            if not both: return Endpoints.EMPTY_MATCH
+            # pick the prefix whose client address pool is least well-served
+            # TODO: this may be manipulated by clients, needs research
+            assert all(client_pre.itervalues()) # no pool is empty
+            pre = min(both, key=lambda p: ptsClient._avServed(client_pre[p]))
+            client_addr = ptsClient._serveReg(client_pre[pre])
+            if not client_addr: return Endpoints.EMPTY_MATCH
+            server_addr = ptsServer._serveReg(server_pre[pre])
+            # assume servers never run out
+            client_transport = ptsClient._endpoints[client_addr]
+            server_transport = ptsServer._endpoints[server_addr]
+            return Reg(client_addr, client_transport), Reg(server_addr, server_transport)
+
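Putting the pieces together, a round trip through the new structures looks like the following sketch (single-letter addresses as in the tests below; a real deployment uses (host, port) tuples):

    clients = Endpoints(af=socket.AF_INET, maxserve=5)
    servers = Endpoints(af=socket.AF_INET)
    clients.addEndpoint("A", "obfs3|websocket")
    servers.addEndpoint("B", "obfs3|websocket")
    # a proxy polls, offering to serve the "websocket" suffix:
    client_reg, server_reg = Endpoints.match(clients, servers, ["websocket"])
    # -> Reg("A", Transport("obfs3", "websocket")),
    #    Reg("B", Transport("obfs3", "websocket"))
    # no usable pair gives Endpoints.EMPTY_MATCH, i.e. (None, None)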
 class Handler(SocketServer.StreamRequestHandler):
     def __init__(self, *args, **kwargs):
@@ -277,16 +362,19 @@ class Handler(SocketServer.StreamRequestHandler):
             return self.error(u"TRANSPORT missing FROM param")
         try:
-            reg = get_reg_for_proxy(proxy_addr, transport_list)
+            client_reg, relay_reg = get_match_for_proxy(proxy_addr, transport_list)
         except Exception as e:
             return self.error(u"error getting match for proxy address %s: %%(cause)s" % safe_str(repr(proxy_spec)), e)
         check_back_in = get_check_back_in_for_proxy(proxy_addr)
-        if reg:
-            log(u"proxy (%s) gets client '%s' (transports: %s) (num relays: %s) (remaining regs: %d/%d)" %
-                (safe_str(repr(proxy_spec)), safe_str(unicode(reg)), reg.transports, num_relays(), num_unhandled_regs(), num_regs()))
-            print >> self.wfile, fac.render_transaction("OK", ("CLIENT", str(reg)), ("RELAY", fac.format_addr(reg.relay)), ("CHECK-BACK-IN", str(check_back_in)))
+        if client_reg:
+            log(u"proxy (%s) gets client '%s' (supported transports: %s) (num relays: %s) (remaining regs: %d/%d)" %
+                (safe_str(repr(proxy_spec)), safe_str(repr(client_reg.addr)), transport_list, num_relays(), num_unhandled_regs(), num_regs()))
+            print >> self.wfile, fac.render_transaction("OK",
+                ("CLIENT", fac.format_addr(client_reg.addr)),
+                ("RELAY", fac.format_addr(relay_reg.addr)),
+                ("CHECK-BACK-IN", str(check_back_in)))
         else:
             log(u"proxy (%s) gets none" % safe_str(repr(proxy_spec)))
             print >> self.wfile, fac.render_transaction("NONE", ("CHECK-BACK-IN", str(check_back_in)))
@@ -297,27 +385,21 @@ class Handler(SocketServer.StreamRequestHandler):
     # Example: PUT CLIENT="1.1.1.1:5555" TRANSPORT_CHAIN="obfs3|websocket"
     def do_PUT(self, params):
         # Check out if we recognize the transport chain in this registration request
-        transports_spec = fac.param_first("TRANSPORT_CHAIN", params)
-        if transports_spec is None:
+        transport = fac.param_first("TRANSPORT_CHAIN", params)
+        if transport is None:
             return self.error(u"PUT missing TRANSPORT_CHAIN param")
-        transports = parse_transport_chain(transports_spec)
-
+        transport = Transport.parse(transport)
         # See if we have relays that support this transport chain
-        if transports not in RELAYS:
-            log(u"Unrecognized transport chain: %s" % transports)
-            self.send_error() # XXX can we tell the flashproxy client of this error?
-            return False
-        # if we have relays that support this transport chain, we
-        # certainly have a regset for its outermost transport too.
-        assert(get_outermost_transport(transports) in REGSETS_IPV4)
+        if all(not pts.supports(transport) for pts in SERVERS.itervalues()):
+            return self.error(u"Unrecognized transport: %s" % transport)
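Spelled out, the new guard rejects a PUT only when no address-family pool might have a matching relay; an equivalent but perhaps clearer form:

    if not any(pts.supports(transport) for pts in SERVERS.itervalues()):
        return self.error(u"Unrecognized transport: %s" % transport)

By design, supports() may return false positives; those are caught later at match time.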
         client_spec = fac.param_first("CLIENT", params)
         if client_spec is None:
             return self.error(u"PUT missing CLIENT param")
         try:
-            reg = Reg.parse(client_spec, transports)
+            reg = Reg.parse(client_spec, transport)
         except (UnknownTransport, ValueError) as e:
             # XXX should we throw a better error message to the client? Is it possible?
             return self.error(u"syntax error in %s: %%(cause)s" % safe_str(repr(client_spec)), e)
@@ -328,9 +410,9 @@ class Handler(SocketServer.StreamRequestHandler):
             return self.error(u"error putting reg %s: %%(cause)s" % safe_str(repr(client_spec)), e)
         if ok:
-            log(u"client %s (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transports, num_unhandled_regs(), num_regs()))
+            log(u"client %s (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transport, num_unhandled_regs(), num_regs()))
         else:
-            log(u"client %s (already present) (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transports, num_unhandled_regs(), num_regs()))
+            log(u"client %s (already present) (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transport, num_unhandled_regs(), num_regs()))
         self.send_ok()
         return True
@@ -341,48 +423,29 @@ class Server(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
     allow_reuse_address = True
-# Registration sets per-outermost-transport
-# {"websocket" : <RegSet for websocket>, "webrtc" : <RegSet for webrtc>}
-REGSETS_IPV4 = {}
-REGSETS_IPV6 = {}
+# Addresses are plain tuples (str(host), int(port))
-def num_regs():
-    """Return the total number of registrations."""
-    num_regs = 0
+CLIENTS = {
+    socket.AF_INET: Endpoints(af=socket.AF_INET, maxserve=MAX_PROXIES_PER_CLIENT),
+    socket.AF_INET6: Endpoints(af=socket.AF_INET6, maxserve=MAX_PROXIES_PER_CLIENT)
+}
+
+SERVERS = {
+    socket.AF_INET: Endpoints(af=socket.AF_INET),
+    socket.AF_INET6: Endpoints(af=socket.AF_INET6)
+}
-    # Iterate the regsets of each regset-dictionary, and count their
-    # registrations.
-    for regset in REGSETS_IPV4.values():
-        num_regs += len(regset)
-    for regset in REGSETS_IPV6.values():
-        num_regs += len(regset)
+def num_relays():
+    """Return the total number of relays."""
+    return sum(pts.getNumEndpoints() for pts in SERVERS.itervalues())
-    return num_regs
+def num_regs():
+    """Return the total number of registrations."""
+    return sum(pts.getNumEndpoints() for pts in CLIENTS.itervalues())
 def num_unhandled_regs():
     """Return the total number of unhandled registrations."""
-    num_regs = 0
-
-    # Iterate the regsets of each regset-dictionary, and count their
-    # unhandled registrations. The first tier of each regset contains
-    # the registrations with no assigned proxy.
-    for regset in REGSETS_IPV4.values():
-        num_regs += len(regset.tiers[0])
-    for regset in REGSETS_IPV6.values():
-        num_regs += len(regset.tiers[0])
-
-    return num_regs
-
-def get_regs(af, transport):
-    """Return the correct regs pool for the given address family and transport."""
-    if transport not in REGSETS_IPV4:
-        raise UnknownTransport("unknown transport '%s'" % transport)
-
-    if af == socket.AF_INET:
-        return REGSETS_IPV4[transport]
-    elif af == socket.AF_INET6:
-        return REGSETS_IPV6[transport]
-    else:
-        raise ValueError("unknown address family %d" % af)
+    return sum(pts.getNumUnservedEndpoints() for pts in CLIENTS.itervalues())
 def addr_af(addr_str):
     """Return the address family for an address string. This is a plain string,
@@ -390,26 +453,12 @@ def addr_af(addr_str):
     addrs = socket.getaddrinfo(addr_str, 0, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_NUMERICHOST)
     return addrs[0][0]
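addr_af maps a literal address string to its address family, which is the key into CLIENTS and SERVERS above; this keying is what enforces the commit message's "don't match ipv6 proxies to ipv4 servers" rule:

    addr_af("1.2.3.4")  # socket.AF_INET  -> selects the IPv4 pools
    addr_af("::1")      # socket.AF_INET6 -> selects the IPv6 pools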
-def get_reg_for_proxy(proxy_addr, transport_list):
-    """Get a client registration appropriate for the given proxy (one
-    of a matching address family). If 'transports' is set, try to find
-    a client registration that supports the outermost transport of a
-    transport chain."""
-    # XXX How should we prioritize transport matching? We currently
-    # just iterate the transport list that was provided by the flashproxy
-    for transport in transport_list:
-        addr_str = proxy_addr[0]
-        af = addr_af(addr_str)
-
-        try:
-            REGS = get_regs(af, transport)
-        except UnknownTransport as e:
-            log(u"%s" % e)
-            continue # move to the next transport
-
-        return REGS.fetch()
-
-    raise UnknownTransport("Could not find registration for transport list: %s" % str(transport_list))
+def get_match_for_proxy(proxy_addr, transport_list):
+    af = addr_af(proxy_addr[0])
+    try:
+        return Endpoints.match(CLIENTS[af], SERVERS[af], transport_list)
+    except ValueError as e:
+        raise UnknownTransport("Could not find registration for transport list: %s: %s" % (transport_list, e))
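From the do_GET path above, a poll from an example IPv4 proxy offering the plain websocket transport reduces to roughly:

    client_reg, relay_reg = get_match_for_proxy(("1.1.1.1", 9000), ["websocket"])
    if client_reg:
        ...  # reply OK with CLIENT=client_reg.addr, RELAY=relay_reg.addr
    # an unmatched poll yields (None, None) and the proxy is answered NONE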
 def get_check_back_in_for_proxy(proxy_addr):
     """Get a CHECK-BACK-IN interval suitable for this proxy."""
@@ -417,29 +466,24 @@
 def put_reg(reg):
     """Add a registration."""
-    addr_str = reg.host
-    af = addr_af(addr_str)
-    REGS = get_regs(af, get_outermost_transport(reg.transports))
-    return REGS.add(reg)
+    af = addr_af(reg.addr[0])
+    return CLIENTS[af].addEndpoint(reg.addr, reg.transport)
-def parse_relay_file(filename):
+def parse_relay_file(servers, fp):
     """Parse a file containing Tor relays that we can point proxies to.

     Throws ValueError on a parsing error. Each line contains a transport chain
     and an address, for example
         obfs2|websocket 1.4.6.1:4123
     """
-    relays = {}
-    with open(filename) as f:
-        for line in f:
-            try:
-                transport_spec, addr_spec = line.strip().split()
-            except ValueError, e:
-                raise ValueError("Wrong line format: %s." % repr(line))
-            addr = fac.parse_addr_spec(addr_spec, defport=DEFAULT_RELAY_PORT, resolve=True)
-            transports = parse_transport_chain(transport_spec)
-            relays.setdefault(transports, [])
-            relays[transports].append(addr)
-    return relays
+    for line in fp.readlines():
+        try:
+            transport_spec, addr_spec = line.strip().split()
+        except ValueError, e:
+            raise ValueError("Wrong line format: %s." % repr(line))
+        addr = fac.parse_addr_spec(addr_spec, defport=DEFAULT_RELAY_PORT, resolve=True)
+        transport = Transport.parse(transport_spec)
+        af = addr_af(addr[0])
+        servers[af].addEndpoint(addr, transport)
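A quick usage sketch, mirroring the docstring's example line (StringIO stands in for the open file here; any file-like object works):

    from StringIO import StringIO  # assumed import for this sketch
    servers = { socket.AF_INET: Endpoints(af=socket.AF_INET) }
    parse_relay_file(servers, StringIO("obfs2|websocket 1.4.6.1:4123\n"))
    # servers[socket.AF_INET]._endpoints
    # -> {('1.4.6.1', 4123): Transport('obfs2', 'websocket')}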
 def main():
     opts, args = getopt.gnu_getopt(sys.argv[1:], "dhl:p:r:",
@@ -474,17 +518,12 @@ obfs2|websocket 1.4.6.1:4123\
 """
         sys.exit(1)
-    RELAYS.update(parse_relay_file(options.relay_filename))
-
-    if not RELAYS:
-        print >> sys.stderr, u"Warning: no relays configured."
-
-    # Create RegSets for our supported transports
-    for transport in RELAYS.keys():
-        outermost_transport = get_outermost_transport(transport)
-        if outermost_transport not in REGSETS_IPV4:
-            REGSETS_IPV4[outermost_transport] = RegSet()
-            REGSETS_IPV6[outermost_transport] = RegSet()
+    try:
+        with open(options.relay_filename) as fp:
+            parse_relay_file(SERVERS, fp)
+    except ValueError as e:
+        print >> sys.stderr, u"Could not parse file '%s': %s" % (repr(a), str(e))
+        sys.exit(1)
     # Setup log file
     if options.log_filename:
@@ -499,7 +538,8 @@ obfs2|websocket 1.4.6.1:4123\
     server = Server(addrinfo[4], Handler)
log(u"start on %s" % fac.format_addr(addrinfo[4])) - log(u"using relays %s" % str(RELAYS)) + log(u"using IPv4 relays %s" % str(SERVERS[socket.AF_INET]._endpoints)) + log(u"using IPv6 relays %s" % str(SERVERS[socket.AF_INET6]._endpoints))
     if options.daemonize:
         log(u"daemonizing")
diff --git a/facilitator/facilitator-test b/facilitator/facilitator-test
index e39cecd..b81b84e 100755
--- a/facilitator/facilitator-test
+++ b/facilitator/facilitator-test
@@ -7,7 +7,7 @@ import tempfile
 import time
 import unittest
-from facilitator import parse_transport_chain
+from facilitator import Transport, Reg, Endpoints, parse_relay_file
 import fac
FACILITATOR_HOST = "127.0.0.1" @@ -23,11 +23,169 @@ def gimme_socket(host, port): s.connect(addrinfo[4]) return s
+class EndpointsTest(unittest.TestCase):
+
+    def setUp(self):
+        self.pts = Endpoints(af=socket.AF_INET)
+
+    def _observeProxySupporting(self, *supported_suf):
+        # semantically observe the existence of a proxy, to make our intent
+        # a bit clearer than simply calling findPrefixesForSuffixes
+        self.pts._findPrefixesForSuffixes(*supported_suf)
+
+    def test_addEndpoints_twice(self):
+        self.pts.addEndpoint("A", "a|b|p")
+        self.assertFalse(self.pts.addEndpoint("A", "zzz"))
+        self.assertEquals(self.pts._endpoints["A"], Transport("a|b", "p"))
+
+    def test_addEndpoints_lazy_indexing(self):
+        self.pts.addEndpoint("A", "a|b|p")
+        default_index = {"websocket": set()} # we always index known_suffixes
+
+        # no index until we've asked for it
+        self.assertEquals(self.pts._indexes, default_index)
+        self._observeProxySupporting("p")
+        self.assertEquals(self.pts._indexes["p"], set("A"))
+
+        # indexes are updated correctly after observing new addresses
+        self.pts.addEndpoint("B", "c|p")
+        self.assertEquals(self.pts._indexes["p"], set("AB"))
+
+        # indexes are updated correctly after observing new proxies
+        self.pts.addEndpoint("C", "a|q")
+        self._observeProxySupporting("q")
+        self.assertEquals(self.pts._indexes["q"], set("C"))
+
+    def test_supports_default(self):
+        # we know there are websocket-capable proxies out there;
+        # support them implicitly without needing to see a proxy.
+        self.pts.addEndpoint("A", "obfs3|websocket")
+        self.assertTrue(self.pts.supports("obfs3|websocket"))
+        self.assertFalse(self.pts.supports("xxx|websocket"))
+        self.assertFalse(self.pts.supports("websocket"))
+        self.assertFalse(self.pts.supports("unknownwhat"))
+        # doesn't matter what the first part is
+        self.pts.addEndpoint("B", "xxx|websocket")
+        self.assertTrue(self.pts.supports("xxx|websocket"))
+
+    def test_supports_seen_proxy(self):
+        # OTOH if some 3rd-party proxy decides to implement its own transport
+        # we are fully capable of supporting them too, but only if we have
+        # an endpoint that also speaks it.
+        self.assertFalse(self.pts.supports("obfs3|unknownwhat"))
+        suf = self._observeProxySupporting("unknownwhat")
+        self.assertFalse(self.pts.supports("obfs3|unknownwhat"))
+        self.pts.addEndpoint("A", "obfs3|unknownwhat")
+        self.assertTrue(self.pts.supports("obfs3|unknownwhat"))
+        self.assertFalse(self.pts.supports("obfs2|unknownwhat"))
+
+    def _test_serveReg_maxserve_infinite_roundrobin(self):
+        # case for servers, they never exhaust
+        self.pts.addEndpoint("A", "a|p")
+        self.pts.addEndpoint("B", "a|p")
+        self.pts.addEndpoint("C", "a|p")
+        for i in xrange(64): # 64 is infinite ;)
+            served = set()
+            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC"))
+            self.assertEquals(served, set("ABC"))
+
+    def _test_serveReg_maxserve_finite_exhaustion(self):
+        # case for clients, we don't want to keep serving them
+        self.pts = Endpoints(af=socket.AF_INET, maxserve=5)
+        self.pts.addEndpoint("A", "a|p")
+        self.pts.addEndpoint("B", "a|p")
+        self.pts.addEndpoint("C", "a|p")
+        # test getNumUnservedEndpoints whilst we're at it
+        self.assertEquals(self.pts.getNumUnservedEndpoints(), 3)
+        for i in xrange(5):
+            served = set()
+            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC"))
+            self.assertEquals(served, set("ABC"))
+        self.assertEquals(None, self.pts._serveReg("ABC"))
+        self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
+
+    def test_match_normal(self):
+        self.pts.addEndpoint("A", "a|p")
+        self.pts2 = Endpoints(af=socket.AF_INET)
+        self.pts2.addEndpoint("B", "a|p")
+        self.pts2.addEndpoint("C", "b|p")
+        self.pts2.addEndpoint("D", "a|q")
+        expected = (Reg("A", Transport("a","p")), Reg("B", Transport("a","p")))
+        empty = Endpoints.EMPTY_MATCH
+        self.assertEquals(expected, Endpoints.match(self.pts, self.pts2, ["p"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
+    def test_match_unequal_client_server(self):
+        self.pts.addEndpoint("A", "a|p")
+        self.pts2 = Endpoints(af=socket.AF_INET)
+        self.pts2.addEndpoint("B", "a|q")
+        expected = (Reg("A", Transport("a","p")), Reg("B", Transport("a","q")))
+        empty = Endpoints.EMPTY_MATCH
+        self.assertEquals(expected, Endpoints.match(self.pts, self.pts2, ["p", "q"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["p"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["q"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
+    def test_match_raw_server(self):
+        self.pts.addEndpoint("A", "p")
+        self.pts2 = Endpoints(af=socket.AF_INET)
+        self.pts2.addEndpoint("B", "p")
+        expected = (Reg("A", Transport("","p")), Reg("B", Transport("","p")))
+        empty = Endpoints.EMPTY_MATCH
+        self.assertEquals(expected, Endpoints.match(self.pts, self.pts2, ["p"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
+    def test_match_many_prefixes(self):
+        self.pts.addEndpoint("A", "a|p")
+        self.pts.addEndpoint("B", "b|p")
+        self.pts.addEndpoint("C", "p")
+        self.pts2 = Endpoints(af=socket.AF_INET)
+        self.pts2.addEndpoint("D", "a|p")
+        self.pts2.addEndpoint("E", "b|p")
+        self.pts2.addEndpoint("F", "p")
+        # this test ensures we have a sane policy for selecting between prefix pools
+        expected = set()
+        expected.add((Reg("A", Transport("a","p")), Reg("D", Transport("a","p"))))
+        expected.add((Reg("B", Transport("b","p")), Reg("E", Transport("b","p"))))
+        expected.add((Reg("C", Transport("","p")), Reg("F", Transport("","p"))))
+        result = set()
+        result.add(Endpoints.match(self.pts, self.pts2, ["p"]))
+        result.add(Endpoints.match(self.pts, self.pts2, ["p"]))
+        result.add(Endpoints.match(self.pts, self.pts2, ["p"]))
+        empty = Endpoints.EMPTY_MATCH
+        self.assertEquals(expected, result)
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
 class FacilitatorTest(unittest.TestCase):
-    def test_parse_transport_chain(self):
-        self.assertEquals(parse_transport_chain("a"), ("a",))
-        self.assertEquals(parse_transport_chain("a|b|c"), ("a","b","c"))
+    def test_transport_parse(self):
+        self.assertEquals(Transport.parse("a"), Transport("", "a"))
+        self.assertEquals(Transport.parse("|a"), Transport("", "a"))
+        self.assertEquals(Transport.parse("a|b|c"), Transport("a|b","c"))
+        self.assertEquals(Transport.parse(Transport("a|b","c")), Transport("a|b","c"))
+        self.assertRaises(ValueError, Transport, "", "")
+        self.assertRaises(ValueError, Transport, "a", "")
+        self.assertRaises(ValueError, Transport.parse, "")
+        self.assertRaises(ValueError, Transport.parse, "|")
+        self.assertRaises(ValueError, Transport.parse, "a|")
+        self.assertRaises(ValueError, Transport.parse, ["a"])
+        self.assertRaises(ValueError, Transport.parse, [Transport("a", "b")])
+
+    def test_parse_relay_file(self):
+        fp = StringIO()
+        fp.write("websocket 0.0.1.0:1\n")
+        fp.flush()
+        fp.seek(0)
+        af = socket.AF_INET
+        servers = { af: Endpoints(af=af) }
+        parse_relay_file(servers, fp)
+        self.assertEquals(servers[af]._endpoints, {('0.0.1.0', 1): Transport('', 'websocket')})
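Worth noting for readers of these assertions: Transport and Reg are namedtuples, so assertEquals compares them by plain tuple value:

    Transport("a|b", "c") == ("a|b", "c")   # True
    Reg("A", Transport("", "p")).addr       # "A"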
 class FacilitatorProcTest(unittest.TestCase):
     IPV4_CLIENT_ADDR = ("1.1.1.1", 9000)