[tor-commits] [tor/master] Add support for lower-level byte counting with NSS

nickm at torproject.org nickm at torproject.org
Fri Sep 14 16:45:37 UTC 2018


commit 126819c94702df2b0eb8cbfaaf4ad81873b94019
Author: Nick Mathewson <nickm at torproject.org>
Date:   Thu Sep 13 10:18:34 2018 -0400

    Add support for lower-level byte counting with NSS
    
    This is harder than with OpenSSL, since OpenSSL counts the bytes on
    its own and NSS doesn't.  To fix this, we need to define a new
    PRFileDesc layer that has its own byte-counting support.
    
    Closes ticket 27289.
---
 src/lib/tls/include.am       |   2 +
 src/lib/tls/nss_countbytes.c | 244 +++++++++++++++++++++++++++++++++++++++++++
 src/lib/tls/nss_countbytes.h |  25 +++++
 src/lib/tls/tortls_nss.c     |  33 +++---
 src/lib/tls/tortls_st.h      |   5 +-
 5 files changed, 295 insertions(+), 14 deletions(-)

diff --git a/src/lib/tls/include.am b/src/lib/tls/include.am
index b25e2e16b..a664b29fb 100644
--- a/src/lib/tls/include.am
+++ b/src/lib/tls/include.am
@@ -12,6 +12,7 @@ src_lib_libtor_tls_a_SOURCES =			\
 
 if USE_NSS
 src_lib_libtor_tls_a_SOURCES +=			\
+	src/lib/tls/nss_countbytes.c		\
 	src/lib/tls/tortls_nss.c		\
 	src/lib/tls/x509_nss.c
 else
@@ -31,6 +32,7 @@ src_lib_libtor_tls_testing_a_CFLAGS = \
 noinst_HEADERS +=				\
 	src/lib/tls/ciphers.inc			\
 	src/lib/tls/buffers_tls.h		\
+	src/lib/tls/nss_countbytes.h		\
 	src/lib/tls/tortls.h			\
 	src/lib/tls/tortls_internal.h		\
 	src/lib/tls/tortls_st.h			\
diff --git a/src/lib/tls/nss_countbytes.c b/src/lib/tls/nss_countbytes.c
new file mode 100644
index 000000000..c72768452
--- /dev/null
+++ b/src/lib/tls/nss_countbytes.c
@@ -0,0 +1,244 @@
+/* Copyright 2018, The Tor Project, Inc. */
+/* See LICENSE for licensing information */
+
+/**
+ * \file nss_countbytes.c
+ * \brief A PRFileDesc layer to let us count the number of bytes
+ *        actually written on a PRFileDesc.
+ **/
+
+#include "orconfig.h"
+
+#include "lib/log/util_bug.h"
+#include "lib/malloc/malloc.h"
+#include "lib/tls/nss_countbytes.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+#include <prio.h>
+
+/** Boolean: have we initialized this module */
+static bool countbytes_initialized = false;
+
+/** Integer to identify this layer. */
+static PRDescIdentity countbytes_layer_id = PR_INVALID_IO_LAYER;
+
+/** Table of methods for this layer.*/
+static PRIOMethods countbytes_methods;
+
+/** Default close function provided by NSPR.  We use this to help
+ *  implement our own close function.*/
+static PRStatus(*default_close_fn)(PRFileDesc *fd);
+
+static PRStatus countbytes_close_fn(PRFileDesc *fd);
+static PRInt32 countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount);
+static PRInt32 countbytes_write_fn(PRFileDesc *fd, const void *buf,
+                                   PRInt32 amount);
+static PRInt32 countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov,
+                                    PRInt32 size, PRIntervalTime timeout);
+static PRInt32 countbytes_send_fn(PRFileDesc *fd, const void *buf,
+                                  PRInt32 amount, PRIntn flags,
+                                  PRIntervalTime timeout);
+static PRInt32 countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount,
+                                  PRIntn flags, PRIntervalTime timeout);
+
+/** Private fields for the byte-counter layer.  We cast this to and from
+ * PRFilePrivate*, which is supposed to be allowed. */
+typedef struct tor_nss_bytecounts_t {
+  uint64_t n_read;
+  uint64_t n_written;
+} tor_nss_bytecounts_t;
+
+/**
+ * Initialize this module, if it is not already initialized.
+ **/
+void
+tor_nss_countbytes_init(void)
+{
+  if (countbytes_initialized)
+    return;
+
+  countbytes_layer_id = PR_GetUniqueIdentity("Tor byte-counting layer");
+  tor_assert(countbytes_layer_id != PR_INVALID_IO_LAYER);
+
+  memcpy(&countbytes_methods, PR_GetDefaultIOMethods(), sizeof(PRIOMethods));
+
+  default_close_fn = countbytes_methods.close;
+  countbytes_methods.close = countbytes_close_fn;
+  countbytes_methods.read = countbytes_read_fn;
+  countbytes_methods.write = countbytes_write_fn;
+  countbytes_methods.writev = countbytes_writev_fn;
+  countbytes_methods.send = countbytes_send_fn;
+  countbytes_methods.recv = countbytes_recv_fn;
+  /* NOTE: We aren't wrapping recvfrom, sendto, or sendfile, since I think
+   * NSS won't be using them for TLS connections. */
+
+  countbytes_initialized = true;
+}
+
+/**
+ * Return the tor_nss_bytecounts_t object for a given IO layer. Asserts that
+ * the IO layer is in fact a layer created by this module.
+ */
+static tor_nss_bytecounts_t *
+get_counts(PRFileDesc *fd)
+{
+  tor_assert(fd->identity == countbytes_layer_id);
+  return (tor_nss_bytecounts_t*) fd->secret;
+}
+
+/** Helper: increment the read-count of an fd by n. */
+#define INC_READ(fd, n) STMT_BEGIN                      \
+    get_counts(fd)->n_read += (n);                      \
+  STMT_END
+
+/** Helper: increment the write-count of an fd by n. */
+#define INC_WRITTEN(fd, n) STMT_BEGIN                      \
+    get_counts(fd)->n_written += (n);                      \
+  STMT_END
+
+/** Implementation for PR_Close: frees the 'secret' field, then passes control
+ * to the default close function */
+static PRStatus
+countbytes_close_fn(PRFileDesc *fd)
+{
+  tor_assert(fd);
+
+  tor_nss_bytecounts_t *counts = (tor_nss_bytecounts_t *)fd->secret;
+  tor_free(counts);
+  fd->secret = NULL;
+
+  return default_close_fn(fd);
+}
+
+/** Implementation for PR_Read: Calls the lower-level read function,
+ * and records what it said. */
+static PRInt32
+countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount)
+{
+  tor_assert(fd);
+  tor_assert(fd->lower);
+
+  PRInt32 result = (fd->lower->methods->read)(fd->lower, buf, amount);
+  if (result > 0)
+    INC_READ(fd, result);
+  return result;
+}
+/** Implementation for PR_Write: Calls the lower-level write function,
+ * and records what it said. */
+static PRInt32
+countbytes_write_fn(PRFileDesc *fd, const void *buf, PRInt32 amount)
+{
+  tor_assert(fd);
+  tor_assert(fd->lower);
+
+  PRInt32 result = (fd->lower->methods->write)(fd->lower, buf, amount);
+  if (result > 0)
+    INC_WRITTEN(fd, result);
+  return result;
+}
+/** Implementation for PR_Writev: Calls the lower-level writev function,
+ * and records what it said. */
+static PRInt32
+countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov,
+                     PRInt32 size, PRIntervalTime timeout)
+{
+  tor_assert(fd);
+  tor_assert(fd->lower);
+
+  PRInt32 result = (fd->lower->methods->writev)(fd->lower, iov, size, timeout);
+  if (result > 0)
+    INC_WRITTEN(fd, result);
+  return result;
+}
+/** Implementation for PR_Send: Calls the lower-level send function,
+ * and records what it said. */
+static PRInt32
+countbytes_send_fn(PRFileDesc *fd, const void *buf,
+                   PRInt32 amount, PRIntn flags, PRIntervalTime timeout)
+{
+  tor_assert(fd);
+  tor_assert(fd->lower);
+
+  PRInt32 result = (fd->lower->methods->send)(fd->lower, buf, amount, flags,
+                                              timeout);
+  if (result > 0)
+    INC_WRITTEN(fd, result);
+  return result;
+}
+/** Implementation for PR_Recv: Calls the lower-level recv function,
+ * and records what it said. */
+static PRInt32
+countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount,
+                                  PRIntn flags, PRIntervalTime timeout)
+{
+  tor_assert(fd);
+  tor_assert(fd->lower);
+
+  PRInt32 result = (fd->lower->methods->recv)(fd->lower, buf, amount, flags,
+                                              timeout);
+  if (result > 0)
+    INC_READ(fd, result);
+  return result;
+}
+
+/**
+ * Wrap a PRFileDesc from NSPR with a new PRFileDesc that will count the
+ * total number of bytes read and written.  Return the new PRFileDesc.
+ *
+ * This function takes ownership of its input.
+ */
+PRFileDesc *
+tor_wrap_prfiledesc_with_byte_counter(PRFileDesc *stack)
+{
+  if (BUG(! countbytes_initialized)) {
+    tor_nss_countbytes_init();
+  }
+
+  tor_nss_bytecounts_t *bytecounts = tor_malloc_zero(sizeof(*bytecounts));
+
+  PRFileDesc *newfd = PR_CreateIOLayerStub(countbytes_layer_id,
+                                           &countbytes_methods);
+  tor_assert(newfd);
+  newfd->secret = (PRFilePrivate *)bytecounts;
+
+  /* This does some complicated messing around with the headers of these
+     objects; see the NSPR documentation for more. The upshot is that
+     after PushIOLayer, "stack" will be the head of the stack.
+  */
+  PRStatus status = PR_PushIOLayer(stack, PR_TOP_IO_LAYER, newfd);
+  tor_assert(status == PR_SUCCESS);
+
+  return stack;
+}
+
+/**
+ * Given a PRFileDesc returned by tor_wrap_prfiledesc_with_byte_counter(),
+ * or another PRFileDesc wrapping that PRFileDesc, set the provided
+ * pointers to the number of bytes read and written on the descriptor since
+ * it was created.
+ *
+ * Return 0 on success, -1 on failure.
+ */
+int
+tor_get_prfiledesc_byte_counts(PRFileDesc *fd,
+                               uint64_t *n_read_out,
+                               uint64_t *n_written_out)
+{
+  if (BUG(! countbytes_initialized)) {
+    tor_nss_countbytes_init();
+  }
+
+  tor_assert(fd);
+  PRFileDesc *bclayer = PR_GetIdentitiesLayer(fd, countbytes_layer_id);
+  if (BUG(bclayer == NULL))
+    return -1;
+
+  tor_nss_bytecounts_t *counts = get_counts(bclayer);
+
+  *n_read_out = counts->n_read;
+  *n_written_out = counts->n_written;
+
+  return 0;
+}
diff --git a/src/lib/tls/nss_countbytes.h b/src/lib/tls/nss_countbytes.h
new file mode 100644
index 000000000..f26280edf
--- /dev/null
+++ b/src/lib/tls/nss_countbytes.h
@@ -0,0 +1,25 @@
+/* Copyright 2018, The Tor Project, Inc. */
+/* See LICENSE for licensing information */
+
+/**
+ * \file nss_countbytes.h
+ * \brief Header for nss_countbytes.c, which lets us count the number of
+ *        bytes actually written on a PRFileDesc.
+ **/
+
+#ifndef TOR_NSS_COUNTBYTES_H
+#define TOR_NSS_COUNTBYTES_H
+
+#include "lib/cc/torint.h"
+
+void tor_nss_countbytes_init(void);
+
+struct PRFileDesc;
+struct PRFileDesc *tor_wrap_prfiledesc_with_byte_counter(
+                                               struct PRFileDesc *stack);
+
+int tor_get_prfiledesc_byte_counts(struct PRFileDesc *fd,
+                                   uint64_t *n_read_out,
+                                   uint64_t *n_written_out);
+
+#endif
diff --git a/src/lib/tls/tortls_nss.c b/src/lib/tls/tortls_nss.c
index 53adfedf3..0944c57a3 100644
--- a/src/lib/tls/tortls_nss.c
+++ b/src/lib/tls/tortls_nss.c
@@ -31,11 +31,12 @@
 #include "lib/tls/tortls.h"
 #include "lib/tls/tortls_st.h"
 #include "lib/tls/tortls_internal.h"
+#include "lib/tls/nss_countbytes.h"
 #include "lib/log/util_bug.h"
 
 DISABLE_GCC_WARNING(strict-prototypes)
 #include <prio.h>
-// For access to raw sockets.
+// For access to raw sockets.
 #include <private/pprio.h>
 #include <ssl.h>
 #include <sslt.h>
@@ -158,6 +159,8 @@ tor_tls_context_new(crypto_pk_t *identity,
   SECStatus s;
   tor_assert(identity);
 
+  tor_tls_init();
+
   tor_tls_context_t *ctx = tor_malloc_zero(sizeof(tor_tls_context_t));
   ctx->refcnt = 1;
 
@@ -320,7 +323,7 @@ tor_tls_get_state_description(tor_tls_t *tls, char *buf, size_t sz)
 void
 tor_tls_init(void)
 {
-  /* We don't have any global setup to do yet, but that will change */
+  tor_nss_countbytes_init();
 }
 
 void
@@ -373,7 +376,11 @@ tor_tls_new(tor_socket_t sock, int is_server)
   if (!tcp)
     return NULL;
 
-  PRFileDesc *ssl = SSL_ImportFD(ctx->ctx, tcp);
+  PRFileDesc *count = tor_wrap_prfiledesc_with_byte_counter(tcp);
+  if (! count)
+    return NULL;
+
+  PRFileDesc *ssl = SSL_ImportFD(ctx->ctx, count);
   if (!ssl) {
     PR_Close(tcp);
     return NULL;
@@ -465,7 +472,6 @@ tor_tls_read, (tor_tls_t *tls, char *cp, size_t len))
   PRInt32 rv = PR_Read(tls->ssl, cp, (int)len);
   // log_debug(LD_NET, "PR_Read(%zu) returned %d", n, (int)rv);
   if (rv > 0) {
-    tls->n_read_since_last_check += rv;
     return rv;
   }
   if (rv == 0)
@@ -489,7 +495,6 @@ tor_tls_write(tor_tls_t *tls, const char *cp, size_t n)
   PRInt32 rv = PR_Write(tls->ssl, cp, (int)n);
   // log_debug(LD_NET, "PR_Write(%zu) returned %d", n, (int)rv);
   if (rv > 0) {
-    tls->n_written_since_last_check += rv;
     return rv;
   }
   if (rv == 0)
@@ -579,13 +584,17 @@ tor_tls_get_n_raw_bytes(tor_tls_t *tls,
   tor_assert(tls);
   tor_assert(n_read);
   tor_assert(n_written);
-  /* XXXX We don't curently have a way to measure this information correctly
-   * in NSS; we could do that with a PRIO layer, but it'll take a little
-   * coding.  For now, we just track the number of bytes sent _in_ the TLS
-   * stream.  Doing this will make our rate-limiting slightly inaccurate. */
-  *n_read = tls->n_read_since_last_check;
-  *n_written = tls->n_written_since_last_check;
-  tls->n_read_since_last_check = tls->n_written_since_last_check = 0;
+  uint64_t r, w;
+  if (tor_get_prfiledesc_byte_counts(tls->ssl, &r, &w) < 0) {
+    *n_read = *n_written = 0;
+    return;
+  }
+
+  *n_read = (size_t)(r - tls->last_read_count);
+  *n_written = (size_t)(w - tls->last_write_count);
+
+  tls->last_read_count = r;
+  tls->last_write_count = w;
 }
 
 int
diff --git a/src/lib/tls/tortls_st.h b/src/lib/tls/tortls_st.h
index a1b59a37a..549443a4e 100644
--- a/src/lib/tls/tortls_st.h
+++ b/src/lib/tls/tortls_st.h
@@ -66,8 +66,9 @@ struct tor_tls_t {
   void *callback_arg;
 #endif
 #ifdef ENABLE_NSS
-  size_t n_read_since_last_check;
-  size_t n_written_since_last_check;
+  /** Last values retrieved from tor_get_prfiledesc_byte_counts(). */
+  uint64_t last_write_count;
+  uint64_t last_read_count;
 #endif
 };
 





More information about the tor-commits mailing list