commit 7787056bc7055bfe8efc5df0af20ca9e13a1ffaa Author: David Fifield <david@bamsoftware.com> Date: Sat Feb 1 17:03:33 2014 -0800
Add flashproxy.util.addr_family function. --- flashproxy/test/test_util.py | 13 ++++++++++++- flashproxy/util.py | 9 +++++++-- 2 files changed, 19 insertions(+), 3 deletions(-)
diff --git a/flashproxy/test/test_util.py b/flashproxy/test/test_util.py index 935dd1f..e095f38 100644 --- a/flashproxy/test/test_util.py +++ b/flashproxy/test/test_util.py @@ -1,8 +1,9 @@ #!/usr/bin/env python
+import socket import unittest
-from flashproxy.util import parse_addr_spec, canonical_ip +from flashproxy.util import parse_addr_spec, canonical_ip, addr_family
class ParseAddrSpecTest(unittest.TestCase): def test_ipv4(self): @@ -39,5 +40,15 @@ class ParseAddrSpecTest(unittest.TestCase): """Test that canonical_ip does not do DNS resolution by default.""" self.assertRaises(ValueError, canonical_ip, *parse_addr_spec("example.com:80"))
+class AddrFamilyTest(unittest.TestCase): + def test_ipv4(self): + self.assertEqual(addr_family("1.2.3.4"), socket.AF_INET) + + def test_ipv6(self): + self.assertEqual(addr_family("1:2::3:4"), socket.AF_INET6) + + def test_name(self): + self.assertRaises(socket.gaierror, addr_family, "localhost") + if __name__ == "__main__": unittest.main() diff --git a/flashproxy/util.py b/flashproxy/util.py index a53bdad..63cdef5 100644 --- a/flashproxy/util.py +++ b/flashproxy/util.py @@ -95,6 +95,12 @@ def canonical_ip(host, port, af=0): except that the host param must already be an IP address.""" return resolve_to_ip(host, port, af, gai_flags=socket.AI_NUMERICHOST)
+def addr_family(ip): + """Return the address family of an IP address. Raises socket.gaierror if ip + is not a numeric IP.""" + addrs = socket.getaddrinfo(ip, 0, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_NUMERICHOST) + return addrs[0][0] + def format_addr(addr): host, port = addr host_str = u"" @@ -102,8 +108,7 @@ def format_addr(addr): if host is not None: # Numeric IPv6 address? try: - addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_NUMERICHOST) - af = addrs[0][0] + af = addr_family(host) except socket.gaierror, e: af = 0 if af == socket.AF_INET6: