commit 2658f6e3eac20a4f7c20e2c6b90ddf253731dce1
Author: Ximin Luo <infinity0(a)gmx.com>
Date: Mon Oct 7 16:47:53 2013 +0100
Reimplement the matching algorithm using an Endpoints data structure for both clients and servers
- select the prefix pool that is least well-served, rather than an arbitrary one
- don't attempt a match until a proxy is available to service the request
- don't match IPv6 proxies to IPv4 servers
---
facilitator/facilitator | 428 +++++++++++++++++++++++-------------------
facilitator/facilitator-test | 166 +++++++++++++++-
2 files changed, 396 insertions(+), 198 deletions(-)
diff --git a/facilitator/facilitator b/facilitator/facilitator
index e44047e..d4088ff 100755
--- a/facilitator/facilitator
+++ b/facilitator/facilitator
@@ -8,6 +8,7 @@ import sys
import threading
import time
import traceback
+from collections import namedtuple
import fac
@@ -65,26 +66,6 @@ again. Listen on 127.0.0.1 and port PORT (by default %(port)d).
"log": DEFAULT_LOG_FILENAME,
}
-def num_relays():
- return sum(len(x) for x in RELAYS.values())
-
-def parse_transport_chain(spec):
- """Parse a transport chain string and return a tuple of individual
- transports, each of which is a string.
- >>> parse_transport_chain("obfs3|websocket")
- ('obfs3', 'websocket')
- """
- assert(spec)
- return tuple(spec.split("|"))
-
-def get_outermost_transport(transports):
- """Given a transport chain tuple, return the last element.
- >>> get_outermost_transport(("obfs3", "websocket"))
- 'websocket'
- """
- assert(transports)
- return transports[-1]
-
def safe_str(s):
"""Return "[scrubbed]" if options.safe_logging is true, and s otherwise."""
if options.safe_logging:
@@ -98,87 +79,191 @@ def log(msg):
print >> options.log_file, (u"%s %s" % (time.strftime(LOG_DATE_FORMAT), msg)).encode("UTF-8")
options.log_file.flush()
-class TCPReg(object):
- def __init__(self, host, port, transports):
- self.host = host
- self.port = port
- self.transports = transports
- # Get a relay for this registration. Throw UnknownTransport if
- # could not be found.
- self.relay = self._get_matching_relay()
- def _get_matching_relay(self):
- """Return a matching relay address for this registration. Raise
- UnknownTransport if a relay with a matching transport chain could not be
- found."""
- if self.transports not in RELAYS:
- raise UnknownTransport("Can't find relay with transport chain: %s" % self.transports)
-
- # Maybe this should be a random pick from the set of all the
- # eligible relays. But let's keep it deterministic for now,
- # and return the first one.
-
- # return random.choice(RELAYS[self.transports])
- return RELAYS[self.transports][0]
+class Transport(namedtuple("Transport", "prefix suffix")):
+ """A transport chain split at its last "|" into (prefix, suffix).
+
+ The suffix is the outermost transport (the part a proxy has to speak,
+ e.g. "websocket") and must be non-empty; the prefix is the rest of the
+ chain and may be empty. For example "obfs3|websocket" parses to
+ Transport("obfs3", "websocket") and "websocket" to
+ Transport("", "websocket").
+ """
+ # keep instances as lightweight as a plain namedtuple (no __dict__)
+ __slots__ = ()
+
+ @classmethod
+ def parse(cls, transport):
+ """Return transport as a Transport.
+
+ :param transport: a Transport (returned as-is) or a chain string
+ such as "obfs3|websocket".
+ :raises ValueError: for any other input type, or if the suffix
+ (outermost) part is empty.
+ """
+ if isinstance(transport, cls):
+ return transport
+ elif isinstance(transport, str):
+ # only the last "|" separates prefix from suffix
+ if "|" in transport:
+ prefix, suffix = transport.rsplit("|", 1)
+ else:
+ prefix, suffix = "", transport
+ return cls(prefix, suffix)
+ else:
+ # format via a 1-tuple: a tuple argument would otherwise be taken
+ # as the %-format argument list and raise TypeError, not ValueError
+ raise ValueError("could not parse transport: %s" % (transport,))
- def __unicode__(self):
- return fac.format_addr((self.host, self.port))
+ def __init__(self, prefix, suffix):
+ # the tuple itself is built by namedtuple's __new__; this only validates
+ if not suffix:
+ raise ValueError("suffix (proxy) part of transport must be non-empty: %s" % str(self))
def __str__(self):
- return unicode(self).encode("UTF-8")
+ # chain syntax again: "prefix|suffix", or just "suffix" when no prefix
+ return "%s|%s" % (self.prefix, self.suffix) if self.prefix else self.suffix
- def __cmp__(self, other):
- if isinstance(other, TCPReg):
- # XXX is this correct comparison?
- return cmp((self.host, self.port, self.transports), (other.host, other.port, other.transports))
- else:
- return False
-class Reg(object):
- @staticmethod
- def parse(spec, transports, defhost = None, defport = None):
+class Reg(namedtuple("Reg", "addr transport")):
+ """A registration: an endpoint address plus the Transport it speaks.
+
+ When built via parse(), addr is the (host, port) pair returned by
+ fac.parse_addr_spec.
+ """
+ __slots__ = ()
+
+ @classmethod
+ def parse(cls, spec, transport, defhost = None, defport = None):
+ """Build a Reg from an address spec and a transport chain.
+
+ :param spec: address spec, parsed by fac.parse_addr_spec with the
+ given defhost/defport.
+ :param transport: a Transport or chain string (see Transport.parse).
+ :raises ValueError: if the transport cannot be parsed; errors from
+ fac.parse_addr_spec propagate unchanged.
+ """
host, port = fac.parse_addr_spec(spec, defhost, defport)
- return TCPReg(host, port, transports)
+ return cls((host, port), Transport.parse(transport))
+
+ def __unicode__(self):
+ # log-friendly address formatted with fac.format_addr, as the removed
+ # TCPReg.__unicode__ did; without this, unicode(reg) in the PUT log
+ # lines falls back to the full namedtuple repr
+ return fac.format_addr(self.addr)
+
+ def __str__(self):
+ return unicode(self).encode("UTF-8")
-class RegSet(object):
- def __init__(self):
- self.tiers = [[] for i in range(MAX_PROXIES_PER_CLIENT)]
- self.cv = threading.Condition()
- def add(self, reg):
- self.cv.acquire()
- try:
- for tier in self.tiers:
- if reg in tier:
- break
- else:
- self.tiers[0].append(reg)
- self.cv.notify()
- return True
- return False
- finally:
- self.cv.release()
-
- def fetch(self):
- self.cv.acquire()
- try:
- for i in range(len(self.tiers)):
- tier = self.tiers[i]
- if tier:
- reg = tier.pop(0)
- if i + 1 < len(self.tiers):
- self.tiers[i+1].append(reg)
- return reg
+class Endpoints(object):
+ """
+ Tracks endpoints (either client/server) and the transport chains that
+ they support.
+
+ Each endpoint address maps to exactly one Transport and has a counter of
+ how many times it has been handed out by _serveReg. Addresses are also
+ indexed by transport suffix (outermost transport): an index is created
+ for each suffix in known_suf at construction, and lazily for any other
+ suffix the first time it is looked up.
+
+ :param af: address family of all endpoints in this instance; match()
+ refuses to pair two instances whose af differs.
+ :param maxserve: how many times a single endpoint may be handed out
+ (default: unlimited).
+ :param known_suf: suffixes to index up front.
+ """
+
+ # Shared by all instances. match() takes this before the two
+ # per-instance locks, so concurrent match() calls never hold those locks
+ # in conflicting orders.
+ matchingLock = threading.Condition()
+
+ def __init__(self, af, maxserve=float("inf"), known_suf=("websocket",)):
+ self.af = af
+ self._maxserve = maxserve
+ self._endpoints = {} # address -> transport
+ self._indexes = {} # suffix -> set of addresses
+ self._served = {} # address -> num_times_served
+ # guards the dicts above and known_suf; used as a mutex here (notify()
+ # is called on changes, but no wait() on it is visible in this patch)
+ self._cv = threading.Condition()
+ self.known_suf = set(known_suf)
+ for suf in self.known_suf:
+ self._ensureIndexForSuffix(suf)
+
+ def getNumEndpoints(self):
+ """:returns: the number of endpoints known to us."""
+ with self._cv:
+ return len(self._endpoints)
+
+ def getNumUnservedEndpoints(self):
+ """:returns: the number of known endpoints that have never been
+ handed out by _serveReg."""
+ with self._cv:
+ return len(filter(lambda t: t == 0, self._served.itervalues()))
+
+ def addEndpoint(self, addr, transport):
+ """Add an endpoint.
+
+ :param addr: Address of endpoint, usage-dependent.
+ :param transport: a Transport, or a chain string accepted by
+ Transport.parse (which raises ValueError if it is malformed).
+ :returns: False if the address is already known, in which case no
+ update is made to its supported transports, else True.
+ """
+ transport = Transport.parse(transport)
+ with self._cv:
+ if addr in self._endpoints: return False
+ self._endpoints[addr] = transport
+ self._served[addr] = 0
+ self._addAddrIntoIndexes(addr)
+ self._cv.notify()
+ return True
+
+ def delEndpoint(self, addr):
+ """Forget an endpoint, including its serve count.
+
+ :param addr: Address of endpoint, usage-dependent.
+ :returns: False if the address was already forgotten, else True.
+ """
+ with self._cv:
+ if addr not in self._endpoints: return False
+ # must run before the _endpoints entry is deleted: it reads the
+ # address's suffix from there
+ self._delAddrFromIndexes(addr)
+ del self._served[addr]
+ del self._endpoints[addr]
+ self._cv.notify()
+ return True
+
+ def supports(self, transport):
+ """
+ Estimate whether we support the given transport. May give false
+ positives, but doing a proper match later on will catch these.
+
+ The suffix must be a known suffix, and the prefix must be the prefix
+ of some endpoint whose suffix is any known suffix (not necessarily
+ this one) -- hence the possible false positives.
+
+ :returns: True if we know, or have met, proxies that might be able
+ to satisfy the requested transport against our known endpoints.
+ """
+ transport = Transport.parse(transport)
+ with self._cv:
+ known_pre = self._findPrefixesForSuffixes(*self.known_suf).keys()
+ pre, suf = transport.prefix, transport.suffix
+ return pre in known_pre and suf in self.known_suf
+
+ def _findPrefixesForSuffixes(self, *supported_suf):
+ """
+ :returns: { prefix: set(addr) }, grouping by prefix every endpoint
+ whose suffix is one of supported_suf.
+
+ Side effect: every suffix in supported_suf is added to known_suf and
+ gets an index. All callers in this class hold self._cv.
+ """
+ # NOTE(review): match() passes the proxy-supplied transport list here,
+ # so known_suf/_indexes grow with every distinct suffix any proxy
+ # reports -- confirm that is acceptable.
+ self.known_suf.update(supported_suf)
+ prefixes = {}
+ for suf in supported_suf:
+ self._ensureIndexForSuffix(suf)
+ for addr in self._indexes[suf]:
+ pre = self._endpoints[addr].prefix
+ prefixes.setdefault(pre, set()).add(addr)
+ return prefixes
+
+ def _avServed(self, addrpool):
+ """:returns: the mean serve count over addrpool, which must be
+ non-empty (ZeroDivisionError otherwise)."""
+ return sum(self._served[a] for a in addrpool) / float(len(addrpool))
+
+ def _serveReg(self, addrpool):
+ """
+ :param addrpool: collection of candidate addresses (a set when
+ called from match()).
+ :returns: An address of an endpoint from the given pool, or None if all
+ endpoints have already been served _maxserve times. The serve
+ counter for that address is also incremented.
+
+ The least-served address is chosen; ties go to whichever comes first
+ in addrpool's iteration order.
+ """
+ if not addrpool: return None
+ prio_addr = min(addrpool, key=lambda a: self._served[a])
+ if self._served[prio_addr] < self._maxserve:
+ self._served[prio_addr] += 1
+ return prio_addr
+ else:
return None
- finally:
- self.cv.release()
- def __len__(self):
- self.cv.acquire()
- try:
- return sum(len(tier) for tier in self.tiers)
- finally:
- self.cv.release()
+ def _ensureIndexForSuffix(self, suf):
+ # build the index for suf from scratch the first time it is needed;
+ # afterwards _addAddrIntoIndexes/_delAddrFromIndexes keep it current
+ if suf in self._indexes: return
+ addrs = set(addr for addr, transport in self._endpoints.iteritems()
+ if transport.suffix == suf)
+ self._indexes[suf] = addrs
+
+ def _addAddrIntoIndexes(self, addr):
+ # only suffixes that already have an index are maintained
+ suf = self._endpoints[addr].suffix
+ if suf in self._indexes: self._indexes[suf].add(addr)
+
+ def _delAddrFromIndexes(self, addr):
+ suf = self._endpoints[addr].suffix
+ if suf in self._indexes: self._indexes[suf].remove(addr)
+
+ # NOTE(review): not called anywhere in this patch. It slices `transport`
+ # like a string, so it does not look usable with the Transport tuples
+ # stored in _endpoints -- confirm before relying on it.
+ def _prefixesForTransport(self, transport, *supported_suf):
+ for suf in supported_suf:
+ if not suf:
+ yield transport
+ elif transport[-len(suf):] == suf:
+ yield transport[:-len(suf)]
+
+ # returned by match() when no (client, server) pair is available
+ EMPTY_MATCH = (None, None)
+ @staticmethod
+ def match(ptsClient, ptsServer, supported_suf):
+ """
+ :returns: A tuple (client Reg, server Reg) arbitrarily selected from
+ the available endpoints that can satisfy supported_suf.
+
+ Only prefixes present on both sides are considered; among those, the
+ one whose client pool has the lowest average serve count is used
+ (ties follow set iteration order). Returns EMPTY_MATCH if no prefix
+ is shared or the chosen client pool is exhausted.
+
+ :raises ValueError: if the two instances have different af.
+ """
+ if ptsClient.af != ptsServer.af:
+ raise ValueError("address family not equal!")
+ # need to operate on both structures
+ # so hold both locks plus a pair-wise lock
+ with Endpoints.matchingLock, ptsClient._cv, ptsServer._cv:
+ server_pre = ptsServer._findPrefixesForSuffixes(*supported_suf)
+ client_pre = ptsClient._findPrefixesForSuffixes(*supported_suf)
+ both = set(server_pre.keys()) & set(client_pre.keys())
+ if not both: return Endpoints.EMPTY_MATCH
+ # pick the prefix whose client address pool is least well-served
+ # TODO: this may be manipulated by clients, needs research
+ assert all(client_pre.itervalues()) # no pool is empty
+ pre = min(both, key=lambda p: ptsClient._avServed(client_pre[p]))
+ client_addr = ptsClient._serveReg(client_pre[pre])
+ # NOTE(review): a falsy address would also be treated as "no match";
+ # `is None` would be stricter
+ if not client_addr: return Endpoints.EMPTY_MATCH
+ server_addr = ptsServer._serveReg(server_pre[pre])
+ # assume servers never run out
+ # (holds with the default unlimited maxserve; with a finite one,
+ # server_addr could be None and the lookup below would raise KeyError)
+ client_transport = ptsClient._endpoints[client_addr]
+ server_transport = ptsServer._endpoints[server_addr]
+ return Reg(client_addr, client_transport), Reg(server_addr, server_transport)
+
class Handler(SocketServer.StreamRequestHandler):
def __init__(self, *args, **kwargs):
@@ -277,16 +362,19 @@ class Handler(SocketServer.StreamRequestHandler):
return self.error(u"TRANSPORT missing FROM param")
try:
- reg = get_reg_for_proxy(proxy_addr, transport_list)
+ client_reg, relay_reg = get_match_for_proxy(proxy_addr, transport_list)
except Exception as e:
return self.error(u"error getting match for proxy address %s: %%(cause)s" % safe_str(repr(proxy_spec)), e)
check_back_in = get_check_back_in_for_proxy(proxy_addr)
- if reg:
- log(u"proxy (%s) gets client '%s' (transports: %s) (num relays: %s) (remaining regs: %d/%d)" %
- (safe_str(repr(proxy_spec)), safe_str(unicode(reg)), reg.transports, num_relays(), num_unhandled_regs(), num_regs()))
- print >> self.wfile, fac.render_transaction("OK", ("CLIENT", str(reg)), ("RELAY", fac.format_addr(reg.relay)), ("CHECK-BACK-IN", str(check_back_in)))
+ if client_reg:
+ log(u"proxy (%s) gets client '%s' (supported transports: %s) (num relays: %s) (remaining regs: %d/%d)" %
+ (safe_str(repr(proxy_spec)), safe_str(repr(client_reg.addr)), transport_list, num_relays(), num_unhandled_regs(), num_regs()))
+ print >> self.wfile, fac.render_transaction("OK",
+ ("CLIENT", fac.format_addr(client_reg.addr)),
+ ("RELAY", fac.format_addr(relay_reg.addr)),
+ ("CHECK-BACK-IN", str(check_back_in)))
else:
log(u"proxy (%s) gets none" % safe_str(repr(proxy_spec)))
print >> self.wfile, fac.render_transaction("NONE", ("CHECK-BACK-IN", str(check_back_in)))
@@ -297,27 +385,21 @@ class Handler(SocketServer.StreamRequestHandler):
# Example: PUT CLIENT="1.1.1.1:5555" TRANSPORT_CHAIN="obfs3|websocket"
def do_PUT(self, params):
# Check out if we recognize the transport chain in this registration request
- transports_spec = fac.param_first("TRANSPORT_CHAIN", params)
- if transports_spec is None:
+ transport = fac.param_first("TRANSPORT_CHAIN", params)
+ if transport is None:
return self.error(u"PUT missing TRANSPORT_CHAIN param")
- transports = parse_transport_chain(transports_spec)
-
+ transport = Transport.parse(transport)
# See if we have relays that support this transport chain
- if transports not in RELAYS:
- log(u"Unrecognized transport chain: %s" % transports)
- self.send_error() # XXX can we tell the flashproxy client of this error?
- return False
- # if we have relays that support this transport chain, we
- # certainly have a regset for its outermost transport too.
- assert(get_outermost_transport(transports) in REGSETS_IPV4)
+ if all(not pts.supports(transport) for pts in SERVERS.itervalues()):
+ return self.error(u"Unrecognized transport: %s" % transport)
client_spec = fac.param_first("CLIENT", params)
if client_spec is None:
return self.error(u"PUT missing CLIENT param")
try:
- reg = Reg.parse(client_spec, transports)
+ reg = Reg.parse(client_spec, transport)
except (UnknownTransport, ValueError) as e:
# XXX should we throw a better error message to the client? Is it possible?
return self.error(u"syntax error in %s: %%(cause)s" % safe_str(repr(client_spec)), e)
@@ -328,9 +410,9 @@ class Handler(SocketServer.StreamRequestHandler):
return self.error(u"error putting reg %s: %%(cause)s" % safe_str(repr(client_spec)), e)
if ok:
- log(u"client %s (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transports, num_unhandled_regs(), num_regs()))
+ log(u"client %s (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transport, num_unhandled_regs(), num_regs()))
else:
- log(u"client %s (already present) (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transports, num_unhandled_regs(), num_regs()))
+ log(u"client %s (already present) (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transport, num_unhandled_regs(), num_regs()))
self.send_ok()
return True
@@ -341,48 +423,29 @@ class Server(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
allow_reuse_address = True
-# Registration sets per-outermost-transport
-# {"websocket" : <RegSet for websocket>, "webrtc" : <RegSet for webrtc>}
-REGSETS_IPV4 = {}
-REGSETS_IPV6 = {}
+# Registries of known endpoints, one Endpoints per address family, so that
+# a proxy is only ever matched with clients and relays of its own family.
+# Addresses are plain tuples (str(host), int(port))
-def num_regs():
- """Return the total number of registrations."""
- num_regs = 0
+# Client registrations; each client is handed out to at most
+# MAX_PROXIES_PER_CLIENT proxies.
+CLIENTS = {
+ socket.AF_INET: Endpoints(af=socket.AF_INET, maxserve=MAX_PROXIES_PER_CLIENT),
+ socket.AF_INET6: Endpoints(af=socket.AF_INET6, maxserve=MAX_PROXIES_PER_CLIENT)
+}
+
+# Relays loaded by parse_relay_file; these use the default (unlimited)
+# maxserve, so they are never exhausted.
+SERVERS = {
+ socket.AF_INET: Endpoints(af=socket.AF_INET),
+ socket.AF_INET6: Endpoints(af=socket.AF_INET6)
+}
- # Iterate the regsets of each regset-dictionary, and count their
- # registrations.
- for regset in REGSETS_IPV4.values():
- num_regs += len(regset)
- for regset in REGSETS_IPV6.values():
- num_regs += len(regset)
+def num_relays():
+ """Return the total number of relays."""
+ return sum(pts.getNumEndpoints() for pts in SERVERS.itervalues())
- return num_regs
+def num_regs():
+ """Return the total number of registrations."""
+ return sum(pts.getNumEndpoints() for pts in CLIENTS.itervalues())
def num_unhandled_regs():
"""Return the total number of unhandled registrations."""
- num_regs = 0
-
- # Iterate the regsets of each regset-dictionary, and count their
- # unhandled registrations. The first tier of each regset contains
- # the registrations with no assigned proxy.
- for regset in REGSETS_IPV4.values():
- num_regs += len(regset.tiers[0])
- for regset in REGSETS_IPV6.values():
- num_regs += len(regset.tiers[0])
-
- return num_regs
-
-def get_regs(af, transport):
- """Return the correct regs pool for the given address family and transport."""
- if transport not in REGSETS_IPV4:
- raise UnknownTransport("unknown transport '%s'" % transport)
-
- if af == socket.AF_INET:
- return REGSETS_IPV4[transport]
- elif af == socket.AF_INET6:
- return REGSETS_IPV6[transport]
- else:
- raise ValueError("unknown address family %d" % af)
+ # "unhandled" = registered clients not yet handed out to any proxy
+ return sum(pts.getNumUnservedEndpoints() for pts in CLIENTS.itervalues())
def addr_af(addr_str):
"""Return the address family for an address string. This is a plain string,
@@ -390,26 +453,12 @@ def addr_af(addr_str):
addrs = socket.getaddrinfo(addr_str, 0, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_NUMERICHOST)
return addrs[0][0]
-def get_reg_for_proxy(proxy_addr, transport_list):
- """Get a client registration appropriate for the given proxy (one
- of a matching address family). If 'transports' is set, try to find
- a client registration that supports the outermost transport of a
- transport chain."""
- # XXX How should we prioritize transport matching? We currently
- # just iterate the transport list that was provided by the flashproxy
- for transport in transport_list:
- addr_str = proxy_addr[0]
- af = addr_af(addr_str)
-
- try:
- REGS = get_regs(af, transport)
- except UnknownTransport as e:
- log(u"%s" % e)
- continue # move to the next transport
-
- return REGS.fetch()
-
- raise UnknownTransport("Could not find registration for transport list: %s" % str(transport_list))
+def get_match_for_proxy(proxy_addr, transport_list):
+ """Pick a (client, relay) pair for the proxy at proxy_addr.
+
+ Both endpoints come from the registries for the proxy's own address
+ family; transport_list names the outermost transports the proxy speaks.
+
+ :returns: (client Reg, relay Reg), or Endpoints.EMPTY_MATCH when no
+ pair can be made right now.
+ :raises UnknownTransport: when Endpoints.match raises ValueError.
+ """
+ family = addr_af(proxy_addr[0])
+ clients, relays = CLIENTS[family], SERVERS[family]
+ try:
+ return Endpoints.match(clients, relays, transport_list)
+ except ValueError as err:
+ raise UnknownTransport("Could not find registration for transport list: %s: %s" % (transport_list, err))
def get_check_back_in_for_proxy(proxy_addr):
"""Get a CHECK-BACK-IN interval suitable for this proxy."""
@@ -417,29 +466,24 @@ def get_check_back_in_for_proxy(proxy_addr):
def put_reg(reg):
- """Add a registration."""
- addr_str = reg.host
- af = addr_af(addr_str)
- REGS = get_regs(af, get_outermost_transport(reg.transports))
- return REGS.add(reg)
+ """Register a client Reg in the CLIENTS registry for its address family.
+
+ :returns: True if the address was newly added, False if it was already
+ registered (its transport is then left unchanged).
+ """
+ host = reg.addr[0]
+ registry = CLIENTS[addr_af(host)]
+ return registry.addEndpoint(reg.addr, reg.transport)
-def parse_relay_file(filename):
+def parse_relay_file(servers, fp):
"""Parse a file containing Tor relays that we can point proxies to.
Throws ValueError on a parsing error. Each line contains a transport chain
and an address, for example
obfs2|websocket 1.4.6.1:4123
"""
+ # servers: dict mapping address family -> Endpoints (e.g. SERVERS); each
+ # relay is added to the entry for the family of its address (parsed with
+ # resolve=True). fp: any iterable of lines, e.g. an open file.
+ # Blank lines and lines starting with "#" are skipped instead of being
+ # rejected by the two-field split below.
- relays = {}
- with open(filename) as f:
- for line in f:
- try:
- transport_spec, addr_spec = line.strip().split()
- except ValueError, e:
- raise ValueError("Wrong line format: %s." % repr(line))
- addr = fac.parse_addr_spec(addr_spec, defport=DEFAULT_RELAY_PORT, resolve=True)
- transports = parse_transport_chain(transport_spec)
- relays.setdefault(transports, [])
- relays[transports].append(addr)
- return relays
+ for line in fp:
+ content = line.strip()
+ if not content or content.startswith("#"):
+ continue
+ try:
+ transport_spec, addr_spec = content.split()
+ except ValueError:
+ raise ValueError("Wrong line format: %s." % repr(line))
+ addr = fac.parse_addr_spec(addr_spec, defport=DEFAULT_RELAY_PORT, resolve=True)
+ transport = Transport.parse(transport_spec)
+ af = addr_af(addr[0])
+ servers[af].addEndpoint(addr, transport)
def main():
opts, args = getopt.gnu_getopt(sys.argv[1:], "dhl:p:r:",
@@ -474,17 +518,12 @@ obfs2|websocket 1.4.6.1:4123\
"""
sys.exit(1)
- RELAYS.update(parse_relay_file(options.relay_filename))
-
- if not RELAYS:
- print >> sys.stderr, u"Warning: no relays configured."
-
- # Create RegSets for our supported transports
- for transport in RELAYS.keys():
- outermost_transport = get_outermost_transport(transport)
- if outermost_transport not in REGSETS_IPV4:
- REGSETS_IPV4[outermost_transport] = RegSet()
- REGSETS_IPV6[outermost_transport] = RegSet()
+ # parse_relay_file fills SERVERS in place; report errors against the
+ # relay file we actually tried to read
+ try:
+ with open(options.relay_filename) as fp:
+ parse_relay_file(SERVERS, fp)
+ except IOError as e:
+ print >> sys.stderr, u"Could not read file '%s': %s" % (repr(options.relay_filename), str(e))
+ sys.exit(1)
+ except ValueError as e:
+ print >> sys.stderr, u"Could not parse file '%s': %s" % (repr(options.relay_filename), str(e))
+ sys.exit(1)
+
+ if num_relays() == 0:
+ print >> sys.stderr, u"Warning: no relays configured."
# Setup log file
if options.log_filename:
@@ -499,7 +538,8 @@ obfs2|websocket 1.4.6.1:4123\
server = Server(addrinfo[4], Handler)
log(u"start on %s" % fac.format_addr(addrinfo[4]))
- log(u"using relays %s" % str(RELAYS))
+ log(u"using IPv4 relays %s" % str(SERVERS[socket.AF_INET]._endpoints))
+ log(u"using IPv6 relays %s" % str(SERVERS[socket.AF_INET6]._endpoints))
if options.daemonize:
log(u"daemonizing")
diff --git a/facilitator/facilitator-test b/facilitator/facilitator-test
index e39cecd..b81b84e 100755
--- a/facilitator/facilitator-test
+++ b/facilitator/facilitator-test
@@ -7,7 +7,7 @@ import tempfile
import time
import unittest
-from facilitator import parse_transport_chain
+from facilitator import Transport, Reg, Endpoints, parse_relay_file
import fac
FACILITATOR_HOST = "127.0.0.1"
@@ -23,11 +23,169 @@ def gimme_socket(host, port):
s.connect(addrinfo[4])
return s
+class EndpointsTest(unittest.TestCase):
+ """Unit tests for Endpoints; endpoint addresses are plain strings here."""
+
+ def setUp(self):
+ self.pts = Endpoints(af=socket.AF_INET)
+
+ def _observeProxySupporting(self, *supported_suf):
+ # semantically observe the existence of a proxy, to make our intent
+ # a bit clearer than simply calling findPrefixesForSuffixes
+ self.pts._findPrefixesForSuffixes(*supported_suf)
+
+ def test_addEndpoints_twice(self):
+ self.pts.addEndpoint("A", "a|b|p")
+ self.assertFalse(self.pts.addEndpoint("A", "zzz"))
+ self.assertEquals(self.pts._endpoints["A"], Transport("a|b", "p"))
+
+ def test_addEndpoints_lazy_indexing(self):
+ self.pts.addEndpoint("A", "a|b|p")
+ default_index = {"websocket": set()} # we always index known_suffixes
+
+ # no index until we've asked for it
+ self.assertEquals(self.pts._indexes, default_index)
+ self._observeProxySupporting("p")
+ self.assertEquals(self.pts._indexes["p"], set("A"))
+
+ # indexes are updated correctly after observing new addresses
+ self.pts.addEndpoint("B", "c|p")
+ self.assertEquals(self.pts._indexes["p"], set("AB"))
+
+ # indexes are updated correctly after observing new proxies
+ self.pts.addEndpoint("C", "a|q")
+ self._observeProxySupporting("q")
+ self.assertEquals(self.pts._indexes["q"], set("C"))
+
+ def test_supports_default(self):
+ # we know there are websocket-capable proxies out there;
+ # support them implicitly without needing to see a proxy.
+ self.pts.addEndpoint("A", "obfs3|websocket")
+ self.assertTrue(self.pts.supports("obfs3|websocket"))
+ self.assertFalse(self.pts.supports("xxx|websocket"))
+ self.assertFalse(self.pts.supports("websocket"))
+ self.assertFalse(self.pts.supports("unknownwhat"))
+ # doesn't matter what the first part is
+ self.pts.addEndpoint("B", "xxx|websocket")
+ self.assertTrue(self.pts.supports("xxx|websocket"))
+
+ def test_supports_seen_proxy(self):
+ # OTOH if some 3rd-party proxy decides to implement its own transport
+ # we are fully capable of supporting them too, but only if we have
+ # an endpoint that also speaks it.
+ self.assertFalse(self.pts.supports("obfs3|unknownwhat"))
+ self._observeProxySupporting("unknownwhat")
+ self.assertFalse(self.pts.supports("obfs3|unknownwhat"))
+ self.pts.addEndpoint("A", "obfs3|unknownwhat")
+ self.assertTrue(self.pts.supports("obfs3|unknownwhat"))
+ self.assertFalse(self.pts.supports("obfs2|unknownwhat"))
+
+ def test_serveReg_maxserve_infinite_roundrobin(self):
+ # case for servers, they never exhaust
+ self.pts.addEndpoint("A", "a|p")
+ self.pts.addEndpoint("B", "a|p")
+ self.pts.addEndpoint("C", "a|p")
+ for i in xrange(64): # 64 is infinite ;)
+ served = set()
+ served.add(self.pts._serveReg("ABC"))
+ served.add(self.pts._serveReg("ABC"))
+ served.add(self.pts._serveReg("ABC"))
+ self.assertEquals(served, set("ABC"))
+
+ def test_serveReg_maxserve_finite_exhaustion(self):
+ # case for clients, we don't want to keep serving them
+ self.pts = Endpoints(af=socket.AF_INET, maxserve=5)
+ self.pts.addEndpoint("A", "a|p")
+ self.pts.addEndpoint("B", "a|p")
+ self.pts.addEndpoint("C", "a|p")
+ # test getNumUnservedEndpoints whilst we're at it
+ self.assertEquals(self.pts.getNumUnservedEndpoints(), 3)
+ for i in xrange(5):
+ served = set()
+ served.add(self.pts._serveReg("ABC"))
+ served.add(self.pts._serveReg("ABC"))
+ served.add(self.pts._serveReg("ABC"))
+ self.assertEquals(served, set("ABC"))
+ self.assertEquals(None, self.pts._serveReg("ABC"))
+ self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
+
+ def test_match_normal(self):
+ self.pts.addEndpoint("A", "a|p")
+ self.pts2 = Endpoints(af=socket.AF_INET)
+ self.pts2.addEndpoint("B", "a|p")
+ self.pts2.addEndpoint("C", "b|p")
+ self.pts2.addEndpoint("D", "a|q")
+ expected = (Reg("A", Transport("a","p")), Reg("B", Transport("a","p")))
+ empty = Endpoints.EMPTY_MATCH
+ self.assertEquals(expected, Endpoints.match(self.pts, self.pts2, ["p"]))
+ self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
+ def test_match_unequal_client_server(self):
+ self.pts.addEndpoint("A", "a|p")
+ self.pts2 = Endpoints(af=socket.AF_INET)
+ self.pts2.addEndpoint("B", "a|q")
+ expected = (Reg("A", Transport("a","p")), Reg("B", Transport("a","q")))
+ empty = Endpoints.EMPTY_MATCH
+ self.assertEquals(expected, Endpoints.match(self.pts, self.pts2, ["p", "q"]))
+ self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["p"]))
+ self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["q"]))
+ self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
+ def test_match_raw_server(self):
+ self.pts.addEndpoint("A", "p")
+ self.pts2 = Endpoints(af=socket.AF_INET)
+ self.pts2.addEndpoint("B", "p")
+ expected = (Reg("A", Transport("","p")), Reg("B", Transport("","p")))
+ empty = Endpoints.EMPTY_MATCH
+ self.assertEquals(expected, Endpoints.match(self.pts, self.pts2, ["p"]))
+ self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
+ def test_match_many_prefixes(self):
+ self.pts.addEndpoint("A", "a|p")
+ self.pts.addEndpoint("B", "b|p")
+ self.pts.addEndpoint("C", "p")
+ self.pts2 = Endpoints(af=socket.AF_INET)
+ self.pts2.addEndpoint("D", "a|p")
+ self.pts2.addEndpoint("E", "b|p")
+ self.pts2.addEndpoint("F", "p")
+ # this test ensures we have a sane policy for selecting between prefix pools
+ expected = set()
+ expected.add((Reg("A", Transport("a","p")), Reg("D", Transport("a","p"))))
+ expected.add((Reg("B", Transport("b","p")), Reg("E", Transport("b","p"))))
+ expected.add((Reg("C", Transport("","p")), Reg("F", Transport("","p"))))
+ result = set()
+ result.add(Endpoints.match(self.pts, self.pts2, ["p"]))
+ result.add(Endpoints.match(self.pts, self.pts2, ["p"]))
+ result.add(Endpoints.match(self.pts, self.pts2, ["p"]))
+ empty = Endpoints.EMPTY_MATCH
+ self.assertEquals(expected, result)
+ self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+ self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+ self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
class FacilitatorTest(unittest.TestCase):
- def test_parse_transport_chain(self):
- self.assertEquals(parse_transport_chain("a"), ("a",))
- self.assertEquals(parse_transport_chain("a|b|c"), ("a","b","c"))
+ def test_transport_parse(self):
+ # valid specs: everything before the last "|" is the prefix
+ for spec, expected in [
+ ("a", Transport("", "a")),
+ ("|a", Transport("", "a")),
+ ("a|b|c", Transport("a|b","c")),
+ (Transport("a|b","c"), Transport("a|b","c")),
+ ]:
+ self.assertEquals(Transport.parse(spec), expected)
+ # the suffix (outermost transport) can never be empty
+ for prefix, suffix in [("", ""), ("a", "")]:
+ self.assertRaises(ValueError, Transport, prefix, suffix)
+ # malformed strings, and inputs that are neither strings nor Transports
+ for bad in ["", "|", "a|", ["a"], [Transport("a", "b")]]:
+ self.assertRaises(ValueError, Transport.parse, bad)
+
+ def test_parse_relay_file(self):
+ af = socket.AF_INET
+ servers = {af: Endpoints(af=af)}
+ fp = StringIO()
+ fp.write("websocket 0.0.1.0:1\n")
+ fp.seek(0)
+ parse_relay_file(servers, fp)
+ expected = {('0.0.1.0', 1): Transport('', 'websocket')}
+ self.assertEquals(servers[af]._endpoints, expected)
class FacilitatorProcTest(unittest.TestCase):
IPV4_CLIENT_ADDR = ("1.1.1.1", 9000)