commit bd3c3708a404ccadb58ce8f70fdfc444799ec876 Author: David Fifield david@bamsoftware.com Date: Fri Jul 13 08:07:04 2012 -0700
Handle PUT command. --- facilitator | 132 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 files changed, 129 insertions(+), 3 deletions(-)
diff --git a/facilitator b/facilitator index 90fb1d8..dc6b752 100755 --- a/facilitator +++ b/facilitator @@ -55,6 +55,13 @@ again. Listen on 127.0.0.1 and port PORT (by default %(port)d). "log": DEFAULT_LOG_FILENAME, }
+def safe_str(s): + """Return "[scrubbed]" if options.safe_logging is true, and s otherwise.""" + if options.safe_logging: + return "[scrubbed]" + else: + return s + log_lock = threading.Lock() def log(msg): log_lock.acquire() @@ -173,6 +180,82 @@ def parse_command(line): pairs.append((key, value)) return command, tuple(pairs)
+def param_first(key, params): + for k, v in params: + if key == k: + return v + return None + +class TCPReg(object): + def __init__(self, host, port): + self.host = host + self.port = port + + def __unicode__(self): + return format_addr((self.host, self.port)) + + def __str__(self): + return unicode(self).encode("UTF-8") + + def __cmp__(self, other): + if isinstance(other, TCPReg): + return cmp((self.host, self.port), (other.host, other.port)) + else: + return NotImplemented + +class Reg(object): + @staticmethod + def parse(spec, defhost = None, defport = None): + try: + af, host, port = parse_addr_spec(spec, defhost, defport) + except ValueError: + pass + else: + try: + addrs = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_NUMERICHOST) + except socket.gaierror, e: + raise ValueError("Bad host or port: \"%s\" \"%s\": %s" % (host, port, str(e))) + if not addrs: + raise ValueError("Bad host or port: \"%s\" \"%s\"" % (host, port)) + + host, port = socket.getnameinfo(addrs[0][4], socket.NI_NUMERICHOST | socket.NI_NUMERICSERV) + return TCPReg(host, int(port)) + + raise ValueError("Bad spec format: %s" % repr(spec)) + +class RegSet(object): + def __init__(self): + self.set = [] + self.cv = threading.Condition() + + def add(self, reg): + self.cv.acquire() + try: + if reg not in self.set: + self.set.append(reg) + self.cv.notify() + return True + else: + return False + finally: + self.cv.release() + + def fetch(self): + self.cv.acquire() + try: + if not self.set: + return None + return self.set.pop(0) + finally: + self.cv.release() + + def __len__(self): + self.cv.acquire() + try: + return len(self.set) + finally: + self.cv.release() + class Handler(SocketServer.StreamRequestHandler): def __init__(self, *args, **kwargs): self.deadline = time.time() + CLIENT_TIMEOUT @@ -224,17 +307,60 @@ class Handler(SocketServer.StreamRequestHandler): except socket.error, e: log("socket error after reading %d lines: %s" % (num_lines, str(e))) break -
self.handle_line(line) + if not self.handle_line(line): + break
def handle_line(self, line): if not (len(line) > 0 and line[-1] == '\n'): raise ValueError("No newline at end of string returned by readline") - command, pairs = parse_command(line[:-1]) - print command, pairs + try: + command, params = parse_command(line[:-1]) + except ValueError, e: + log("parse_command: %s" % e) + self.send_error() + return False + + if command == "PUT": + return self.do_PUT(params) + else: + self.send_error() + return False + + def send_ok(self): + print >> self.wfile, "OK" + + def send_error(self): + print >> self.wfile, "ERROR" + + def do_PUT(self, params): + client_spec = param_first("CLIENT", params) + if client_spec is None: + log(u"PUT missing CLIENT param") + self.send_error() + return False + + # NOTE(review): a FROM param, if sent, is currently ignored; only CLIENT is read. + + try: + reg = Reg.parse(client_spec, self.client_address[0]) + except ValueError, e: + log(u"syntax error in %s: %s" % (safe_str(repr(client_spec)), repr(str(e)))) + self.send_error() + return False + + if REGS.add(reg): + log(u"client %s (now %d)" % (safe_str(unicode(reg)), len(REGS))) + else: + log(u"client %s (already present, now %d)" % (safe_str(unicode(reg)), len(REGS))) + + self.send_ok() + return True
class Server(SocketServer.ThreadingMixIn, SocketServer.TCPServer): allow_reuse_address = True
+REGS = RegSet() + def main(): opts, args = getopt.gnu_getopt(sys.argv[1:], "dhl:p:r:", ["debug", "help", "log=", "port=", "pidfile=", "relay=", "unsafe-logging"])