[tor-commits] [pytorctl/master] use local sessions

mikeperry at torproject.org mikeperry at torproject.org
Tue Sep 6 23:00:12 UTC 2011


commit e99394c46b30ed73ffbf02d5c6a39bf9b04416a2
Author: aagbsn <aagbsn at extc.org>
Date:   Wed Aug 31 16:23:42 2011 -0700

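    use local sessions
    
    Each function that touches the database now obtains its own session
    from the tc_session registry (l_session = tc_session()), does its adds
    and commits through that local session, and calls tc_session.remove()
    when it is done, instead of operating on the shared tc_session directly.
    
    see:
    http://www.sqlalchemy.org/docs/orm/session.html#lifespan-of-a-contextual-session
      "This has the effect such that each web request starts fresh with a brand new
      session, and is the most definitive approach to closing out a request."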
---
 SQLSupport.py |  186 +++++++++++++++++++++++++++++++++++++--------------------
 1 files changed, 120 insertions(+), 66 deletions(-)

diff --git a/SQLSupport.py b/SQLSupport.py
index efee35b..b7c9a82 100644
--- a/SQLSupport.py
+++ b/SQLSupport.py
@@ -40,6 +40,13 @@ MIN_RATIO=0.5
 
 NO_FPE=2**-50
 
+#################### Session Usage ###############
+# What is all this l_session madness? See:                                                                
+# http://www.sqlalchemy.org/docs/orm/session.html#lifespan-of-a-contextual-session                        
+#   "This has the effect such that each web request starts fresh with                                     
+#   a brand new session, and is the most definitive approach to closing                                   
+#   out a request." 
+
 #################### Model #######################
 
 # In elixir, the session (DB connection) is a property of the model..
@@ -257,6 +264,7 @@ class RouterStats(Entity):
   filt_sbw_ratio = Field(Float)
 
   def _compute_stats_relation(stats_clause):
+    l_session = tc_session()
     for rs in RouterStats.query.\
                    filter(stats_clause).\
                    options(eagerload_all('router.circuits.extensions')).\
@@ -295,12 +303,15 @@ class RouterStats(Entity):
       if rs.circ_try_to+rs.circ_try_from > 0:
         rs.circ_bi_rate = (1.0*rs.circ_fail_to+rs.circ_fail_from)/(rs.circ_try_to+rs.circ_try_from)
 
-      tc_session.add(rs)
-    tc_session.commit()
+      l_session.add(rs)
+    l_session.commit()
+    tc_session.remove()
   _compute_stats_relation = Callable(_compute_stats_relation)
 
+
   def _compute_stats_query(stats_clause):
     tc_session.expunge_all()
+    l_session = tc_session()
     # http://www.sqlalchemy.org/docs/04/sqlexpression.html#sql_update
     to_s = select([func.count(Extension.id)], 
         and_(stats_clause, Extension.table.c.to_node_idhex
@@ -374,10 +385,12 @@ class RouterStats(Entity):
             tot_var += (s.bandwidth()-rs.sbw)*(s.bandwidth()-rs.sbw)
         tot_var /= s_cnt
         rs.sbw_dev = math.sqrt(tot_var)
-      tc_session.add(rs)
-    tc_session.commit()
+      l_session.add(rs)
+    l_session.commit()
+    tc_session.remove()
   _compute_stats_query = Callable(_compute_stats_query)
 
+
   def _compute_stats(stats_clause):
     RouterStats._compute_stats_query(stats_clause)
     #RouterStats._compute_stats_relation(stats_clause)
@@ -385,6 +398,7 @@ class RouterStats(Entity):
 
   def _compute_ranks():
     tc_session.expunge_all()
+    l_session = tc_session()
     min_r = select([func.min(BwHistory.rank)],
         BwHistory.table.c.router_idhex
             == RouterStats.table.c.router_idhex).as_scalar()
@@ -422,11 +436,13 @@ class RouterStats(Entity):
        {RouterStats.table.c.percentile:
             (100.0*RouterStats.table.c.avg_rank)/max_avg_rank}).execute()
 
-    tc_session.commit()
+    l_session.commit()
+    tc_session.remove()
   _compute_ranks = Callable(_compute_ranks)
 
   def _compute_ratios(stats_clause):
     tc_session.expunge_all()
+    l_session = tc_session()
     avg_from_rate = select([alias(
         select([func.avg(RouterStats.circ_from_rate)],
                            stats_clause)
@@ -459,10 +475,12 @@ class RouterStats(Entity):
          avg_ext/RouterStats.table.c.avg_first_ext,
         RouterStats.table.c.sbw_ratio:
          RouterStats.table.c.sbw/avg_sbw}).execute()
-    tc_session.commit()
+    l_session.commit()
+    tc_session.remove()
   _compute_ratios = Callable(_compute_ratios)
 
   def _compute_filtered_relational(min_ratio, stats_clause, filter_clause):
+    l_session = tc_session()
     badrouters = RouterStats.query.filter(stats_clause).filter(filter_clause).\
                    filter(RouterStats.sbw_ratio < min_ratio).all()
 
@@ -487,18 +505,19 @@ class RouterStats(Entity):
 
       if sbw_cnt: rs.filt_sbw = tot_sbw/sbw_cnt
       else: rs.filt_sbw = None
-      tc_session.add(rs)
+      l_session.add(rs)
     if sqlalchemy.__version__ < "0.5.0":
       avg_sbw = RouterStats.query.filter(stats_clause).avg(RouterStats.filt_sbw)
     else:
-      avg_sbw = tc_session.query(func.avg(RouterStats.filt_sbw)).filter(stats_clause).scalar()
+      avg_sbw = l_session.query(func.avg(RouterStats.filt_sbw)).filter(stats_clause).scalar()
     for rs in RouterStats.query.filter(stats_clause).all():
       if type(rs.filt_sbw) == float and avg_sbw:
         rs.filt_sbw_ratio = rs.filt_sbw/avg_sbw
       else:
         rs.filt_sbw_ratio = None
-      tc_session.add(rs)
-    tc_session.commit()
+      l_session.add(rs)
+    l_session.commit()
+    tc_session.remove()
   _compute_filtered_relational = Callable(_compute_filtered_relational)
 
   def _compute_filtered_ratios(min_ratio, stats_clause, filter_clause):
@@ -509,17 +528,20 @@ class RouterStats(Entity):
 
   def reset():
     tc_session.expunge_all()
+    l_session = tc_session()
     RouterStats.table.drop()
     RouterStats.table.create()
     for r in Router.query.all():
       rs = RouterStats()
       rs.router = r
       r.stats = rs
-      tc_session.add(r)
-    tc_session.commit()
+      l_session.add(r)
+    l_session.commit()
+    tc_session.remove()
   reset = Callable(reset)
 
   def compute(pct_low=0, pct_high=100, stat_clause=None, filter_clause=None):
+    l_session = tc_session()
     pct_clause = and_(RouterStats.percentile >= pct_low, 
                          RouterStats.percentile < pct_high)
     if stat_clause:
@@ -532,10 +554,12 @@ class RouterStats(Entity):
     RouterStats._compute_stats(stat_clause)
     RouterStats._compute_ratios(stat_clause)
     RouterStats._compute_filtered_ratios(MIN_RATIO, stat_clause, filter_clause)
-    tc_session.commit()
+    l_session.commit()
+    tc_session.remove()
   compute = Callable(compute)  
 
   def write_stats(f, pct_low=0, pct_high=100, order_by=None, recompute=False, stat_clause=None, filter_clause=None, disp_clause=None):
+    l_session = tc_session()
 
     if not order_by:
       order_by=RouterStats.avg_first_ext
@@ -557,14 +581,16 @@ class RouterStats(Entity):
       filt_sbw = RouterStats.query.filter(pct_clause).filter(stat_clause).avg(RouterStats.filt_sbw)
       percentile = RouterStats.query.filter(pct_clause).filter(stat_clause).avg(RouterStats.percentile)
     else:
-      circ_from_rate = tc_session.query(func.avg(RouterStats.circ_from_rate)).filter(pct_clause).filter(stat_clause).scalar()
-      circ_to_rate = tc_session.query(func.avg(RouterStats.circ_to_rate)).filter(pct_clause).filter(stat_clause).scalar()
-      circ_bi_rate = tc_session.query(func.avg(RouterStats.circ_bi_rate)).filter(pct_clause).filter(stat_clause).scalar()
+      circ_from_rate = l_session.query(func.avg(RouterStats.circ_from_rate)).filter(pct_clause).filter(stat_clause).scalar()
+      circ_to_rate = l_session.query(func.avg(RouterStats.circ_to_rate)).filter(pct_clause).filter(stat_clause).scalar()
+      circ_bi_rate = l_session.query(func.avg(RouterStats.circ_bi_rate)).filter(pct_clause).filter(stat_clause).scalar()
       
-      avg_first_ext = tc_session.query(func.avg(RouterStats.avg_first_ext)).filter(pct_clause).filter(stat_clause).scalar()
-      sbw = tc_session.query(func.avg(RouterStats.sbw)).filter(pct_clause).filter(stat_clause).scalar()
-      filt_sbw = tc_session.query(func.avg(RouterStats.filt_sbw)).filter(pct_clause).filter(stat_clause).scalar()
-      percentile = tc_session.query(func.avg(RouterStats.percentile)).filter(pct_clause).filter(stat_clause).scalar()
+      avg_first_ext = l_session.query(func.avg(RouterStats.avg_first_ext)).filter(pct_clause).filter(stat_clause).scalar()
+      sbw = l_session.query(func.avg(RouterStats.sbw)).filter(pct_clause).filter(stat_clause).scalar()
+      filt_sbw = l_session.query(func.avg(RouterStats.filt_sbw)).filter(pct_clause).filter(stat_clause).scalar()
+      percentile = l_session.query(func.avg(RouterStats.percentile)).filter(pct_clause).filter(stat_clause).scalar()
+
+    tc_session.remove()
 
     def cvt(a,b,c=1):
       if type(a) == float: return round(a/c,b)
@@ -617,6 +643,7 @@ class RouterStats(Entity):
   
 
   def write_bws(f, pct_low=0, pct_high=100, order_by=None, recompute=False, stat_clause=None, filter_clause=None, disp_clause=None):
+    l_session = tc_session()
     if not order_by:
       order_by=RouterStats.avg_first_ext
 
@@ -631,8 +658,8 @@ class RouterStats(Entity):
       sbw = RouterStats.query.filter(pct_clause).filter(stat_clause).avg(RouterStats.sbw)
       filt_sbw = RouterStats.query.filter(pct_clause).filter(stat_clause).avg(RouterStats.filt_sbw)
     else:
-      sbw = tc_session.query(func.avg(RouterStats.sbw)).filter(pct_clause).filter(stat_clause).scalar()
-      filt_sbw = tc_session.query(func.avg(RouterStats.filt_sbw)).filter(pct_clause).filter(stat_clause).scalar()
+      sbw = l_session.query(func.avg(RouterStats.sbw)).filter(pct_clause).filter(stat_clause).scalar()
+      filt_sbw = l_session.query(func.avg(RouterStats.filt_sbw)).filter(pct_clause).filter(stat_clause).scalar()
 
     f.write(str(int(time.time()))+"\n")
 
@@ -651,6 +678,7 @@ class RouterStats(Entity):
       f.write(" ns_bw="+str(int(cvt(s.avg_bw,0)))+"\n")
 
     f.flush()
+    tc_session.remove()
   write_bws = Callable(write_bws)  
     
 
@@ -658,6 +686,7 @@ class RouterStats(Entity):
 
 #################### Model Support ################
 def reset_all():
+  l_session = tc_session()
   plog("WARN", "SQLSupport.reset_all() called. See SQLSupport.py for details")
   # XXX: We still have a memory leak somewhere in here
   # Current suspects are sqlite, python-sqlite, or sqlalchemy misuse...
@@ -677,9 +706,9 @@ def reset_all():
     r.detached_streams = []
     r.bw_history = [] 
     r.stats = None
-    tc_session.add(r)
+    l_session.add(r)
 
-  tc_session.commit()
+  l_session.commit()
   tc_session.expunge_all()
 
   # XXX: WARNING!
@@ -705,13 +734,14 @@ def reset_all():
   Stream.table.create() 
   Circuit.table.create()
 
-  tc_session.commit()
+  l_session.commit()
 
   #for r in Router.query.all():
   #  if len(r.bw_history) or len(r.circuits) or len(r.streams) or r.stats:
   #    plog("WARN", "Router still has dropped data!")
 
   plog("NOTICE", "Reset all SQL stats")
+  tc_session.remove()
 
 def refresh_all():
   # necessary to keep all sessions synchronized
@@ -737,6 +767,7 @@ class ConsensusTrackerListener(TorCtl.DualEventListener):
 
   # TODO: What about non-running routers and uptime information?
   def _update_rank_history(self, idlist):
+    l_session = tc_session()
     plog("INFO", "Consensus change... Updating rank history")
     for idhex in idlist:
       if idhex not in self.consensus.routers: continue
@@ -748,15 +779,17 @@ class ConsensusTrackerListener(TorCtl.DualEventListener):
         bwh = BwHistory(router=r, rank=rc.list_rank, bw=rc.bw,
                         desc_bw=rc.desc_bw, pub_time=r.published)
         r.bw_history.append(bwh)
-        #tc_session.add(bwh)
-        tc_session.add(r)
+        #l_session.add(bwh)
+        l_session.add(r)
       except sqlalchemy.orm.exc.NoResultFound:
         plog("WARN", "No descriptor found for consenus router "+str(idhex))
 
     plog("INFO", "Consensus history updated.")
-    tc_session.commit()
+    l_session.commit()
+    tc_session.remove()
 
   def _update_db(self, idlist):
+    l_session = tc_session()
     # FIXME: It is tempting to delay this as well, but we need
     # this info to be present immediately for circuit construction...
     plog("INFO", "Consensus change... Updating db")
@@ -769,9 +802,10 @@ class ConsensusTrackerListener(TorCtl.DualEventListener):
           continue
         if not r: r = Router()
         r.from_router(rc)
-        tc_session.add(r)
+        l_session.add(r)
     plog("INFO", "Consensus db updated")
-    tc_session.commit()
+    l_session.commit()
+    tc_session.remove()
     # testing
     #refresh_all() # Too many sessions, don't trust commit()
 
@@ -786,6 +820,7 @@ class ConsensusTrackerListener(TorCtl.DualEventListener):
     TorCtl.DualEventListener.set_parent(self, parent_handler)
 
   def heartbeat_event(self, e):
+    l_session = tc_session()
     # This sketchiness is to ensure we have an accurate history
     # of each router's rank+bandwidth for the entire duration of the run..
     if e.state == EVENT_STATE.PRELISTEN:
@@ -798,8 +833,9 @@ class ConsensusTrackerListener(TorCtl.DualEventListener):
                     orhash="000000000000000000000000000", 
                     nickname="!!TorClient", 
                     published=datetime.datetime.utcnow())
-          tc_session.add(OP)
-          tc_session.commit()
+          l_session.add(OP)
+          l_session.commit()
+          tc_session.remove()
         self.update_consensus()
       # XXX: This hack exists because update_rank_history is expensive.
       # However, even if we delay it till the end of the consensus update, 
@@ -853,7 +889,9 @@ class CircuitListener(TorCtl.PreEventListener):
       self.track_parent = False
 
   def circ_status_event(self, c):
+    l_session = tc_session()
     if self.track_parent and c.circ_id not in self.parent_handler.circuits:
+      tc_session.remove()
       return # Ignore circuits that aren't ours
     # TODO: Hrmm, consider making this sane in TorCtl.
     if c.reason: lreason = c.reason
@@ -883,13 +921,15 @@ class CircuitListener(TorCtl.PreEventListener):
             return
           circ.routers.append(rq) 
           #rq.circuits.append(circ) # done automagically?
-          #tc_session.add(rq)
-      tc_session.add(circ)
-      tc_session.commit()
+          #l_session.add(rq)
+      l_session.add(circ)
+      l_session.commit()
     elif c.status == "EXTENDED":
       circ = Circuit.query.options(eagerload('extensions')).filter_by(
                        circ_id = c.circ_id).first()
-      if not circ: return # Skip circuits from before we came online
+      if not circ: 
+        tc_session.remove()
+        return # Skip circuits from before we came online
 
       e = Extension(circ=circ, hop=len(c.path)-1, time=c.arrived_at)
 
@@ -908,17 +948,19 @@ class CircuitListener(TorCtl.PreEventListener):
         # FIXME: Eager load here?
         circ.routers.append(e.to_node)
         e.to_node.circuits.append(circ)
-        tc_session.add(e.to_node)
+        l_session.add(e.to_node)
  
       e.delta = c.arrived_at - circ.last_extend
       circ.last_extend = c.arrived_at
       circ.extensions.append(e)
-      tc_session.add(e)
-      tc_session.add(circ)
-      tc_session.commit()
+      l_session.add(e)
+      l_session.add(circ)
+      l_session.commit()
     elif c.status == "FAILED":
       circ = Circuit.query.filter_by(circ_id = c.circ_id).first()
-      if not circ: return # Skip circuits from before we came online
+      if not circ: 
+        tc_session.remove()
+        return # Skip circuits from before we came online
         
       circ.expunge()
       if isinstance(circ, BuiltCircuit):
@@ -956,14 +998,16 @@ class CircuitListener(TorCtl.PreEventListener):
         e.reason = reason
         circ.extensions.append(e)
         circ.fail_time = c.arrived_at
-        tc_session.add(e)
+        l_session.add(e)
 
-      tc_session.add(circ)
-      tc_session.commit()
+      l_session.add(circ)
+      l_session.commit()
     elif c.status == "BUILT":
       circ = Circuit.query.filter_by(
                      circ_id = c.circ_id).first()
-      if not circ: return # Skip circuits from before we came online
+      if not circ:
+        tc_session.remove()
+        return # Skip circuits from before we came online
 
       circ.expunge()
       # Convert to built circuit
@@ -973,8 +1017,8 @@ class CircuitListener(TorCtl.PreEventListener):
       
       circ.built_time = c.arrived_at
       circ.tot_delta = c.arrived_at - circ.launch_time
-      tc_session.add(circ)
-      tc_session.commit()
+      l_session.add(circ)
+      l_session.commit()
     elif c.status == "CLOSED":
       circ = BuiltCircuit.query.filter_by(circ_id = c.circ_id).first()
       if circ:
@@ -992,20 +1036,24 @@ class CircuitListener(TorCtl.PreEventListener):
           circ = DestroyedCircuit.query.filter_by(id=circ.id).one()
           circ.destroy_reason = reason
           circ.destroy_time = c.arrived_at
-        tc_session.add(circ)
-        tc_session.commit()
+        l_session.add(circ)
+        l_session.commit()
+    tc_session.remove()
 
 class StreamListener(CircuitListener):
   def stream_bw_event(self, s):
+    l_session = tc_session()
     strm = Stream.query.filter_by(strm_id = s.strm_id).first()
     if strm and strm.start_time and strm.start_time < s.arrived_at:
       plog("DEBUG", "Got stream bw: "+str(s.strm_id))
       strm.tot_read_bytes += s.bytes_read
       strm.tot_write_bytes += s.bytes_written
-      tc_session.add(strm)
-      tc_session.commit()
+      l_session.add(strm)
+      l_session.commit()
+      tc_session.remove()
  
   def stream_status_event(self, s):
+    l_session = tc_session()
     if s.reason: lreason = s.reason
     else: lreason = "NONE"
     if s.remote_reason: rreason = s.remote_reason
@@ -1015,8 +1063,9 @@ class StreamListener(CircuitListener):
       strm = Stream(strm_id=s.strm_id, tgt_host=s.target_host, 
                     tgt_port=s.target_port, init_status=s.status,
                     tot_read_bytes=0, tot_write_bytes=0)
-      tc_session.add(strm)
-      tc_session.commit()
+      l_session.add(strm)
+      l_session.commit()
+      tc_session.remove()
       return
 
     strm = Stream.query.filter_by(strm_id = s.strm_id).first()
@@ -1025,7 +1074,8 @@ class StreamListener(CircuitListener):
            self.parent_handler.streams[s.strm_id].ignored):
       if strm:
         strm.delete()
-        tc_session.commit()
+        l_session.commit()
+        tc_session.remove()
       return # Ignore streams that aren't ours
 
     if not strm: 
@@ -1040,7 +1090,8 @@ class StreamListener(CircuitListener):
       if not strm.circuit:
         plog("NOTICE", "Ignoring prior stream "+str(strm.strm_id)+" with old circuit "+str(s.circ_id))
         strm.delete()
-        tc_session.commit()
+        l_session.commit()
+        tc_session.remove()
         return
     else:
       circ = None
@@ -1066,19 +1117,19 @@ class StreamListener(CircuitListener):
       for r in strm.circuit.routers: 
         plog("DEBUG", "Added router "+r.idhex+" to stream "+str(s.strm_id))
         r.streams.append(strm)
-        tc_session.add(r)
-      tc_session.add(strm)
-      tc_session.commit()
+        l_session.add(r)
+      l_session.add(strm)
+      l_session.commit()
     elif s.status == "DETACHED":
       for r in strm.circuit.routers:
         r.detached_streams.append(strm)
-        tc_session.add(r)
+        l_session.add(r)
       #strm.detached_circuits.append(strm.circuit)
       strm.circuit.detached_streams.append(strm)
       strm.circuit.streams.remove(strm)
       strm.circuit = None
-      tc_session.add(strm)
-      tc_session.commit()
+      l_session.add(strm)
+      l_session.commit()
     elif s.status == "FAILED":
       strm.expunge()
       # Convert to destroyed circuit
@@ -1087,8 +1138,8 @@ class StreamListener(CircuitListener):
       strm = FailedStream.query.filter_by(id=strm.id).one()
       strm.fail_time = s.arrived_at
       strm.fail_reason = reason
-      tc_session.add(strm)
-      tc_session.commit()
+      l_session.add(strm)
+      l_session.commit()
     elif s.status == "CLOSED":
       if isinstance(strm, FailedStream):
         strm.close_reason = reason
@@ -1110,8 +1161,9 @@ class StreamListener(CircuitListener):
           strm.end_time = s.arrived_at
           plog("DEBUG", "Stream "+str(strm.strm_id)+" xmitted "+str(strm.tot_bytes()))
         strm.close_reason = reason
-      tc_session.add(strm)
-      tc_session.commit()
+      l_session.add(strm)
+      l_session.commit()
+    tc_session.remove()
 
 def run_example(host, port):
   """ Example of basic TorCtl usage. See PathSupport for more advanced
@@ -1120,11 +1172,13 @@ def run_example(host, port):
   print "host is %s:%d"%(host,port)
   setup_db("sqlite:///torflow.sqlite", echo=False)
 
-  #print tc_session.query(((func.count(Extension.id)))).filter(and_(FailedExtension.table.c.row_type=='extension', FailedExtension.table.c.from_node_idhex == "7CAA2F5F998053EF5D2E622563DEB4A6175E49AC")).one()
+  #l_session = tc_session()
+  #print l_session.query(((func.count(Extension.id)))).filter(and_(FailedExtension.table.c.row_type=='extension', FailedExtension.table.c.from_node_idhex == "7CAA2F5F998053EF5D2E622563DEB4A6175E49AC")).one()
   #return
   #for e in Extension.query.filter(FailedExtension.table.c.row_type=='extension').all():
   #  if e.from_node: print "From: "+e.from_node.idhex+" "+e.from_node.nickname
   #  if e.to_node: print "To: "+e.to_node.idhex+" "+e.to_node.nickname
+  #tc_session.remove()
   #return
 
   s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)



More information about the tor-commits mailing list