[tests] Add network_thread_ utility functions.

Add network thread_start(), network_thread_running() and
network_thread_join() utility functions in mininode.py and use
network_thread_running() in network thread assertions.
This commit is contained in:
John Newbery 2017-12-08 10:50:24 -05:00
parent f60b4ad579
commit 5fc6e71d19

View File

@ -18,7 +18,7 @@ import logging
import socket import socket
import struct import struct
import sys import sys
from threading import RLock, Thread import threading
from test_framework.messages import * from test_framework.messages import *
from test_framework.util import wait_until from test_framework.util import wait_until
@ -397,9 +397,12 @@ mininode_socket_map = dict()
# and whenever adding anything to the send buffer (in send_message()). This # and whenever adding anything to the send buffer (in send_message()). This
# lock should be acquired in the thread running the test logic to synchronize # lock should be acquired in the thread running the test logic to synchronize
# access to any data shared with the P2PInterface or P2PConnection. # access to any data shared with the P2PInterface or P2PConnection.
mininode_lock = RLock() mininode_lock = threading.RLock()
class NetworkThread(threading.Thread):
def __init__(self):
super().__init__(name="NetworkThread")
class NetworkThread(Thread):
def run(self): def run(self):
while mininode_socket_map: while mininode_socket_map:
# We check for whether to disconnect outside of the asyncore # We check for whether to disconnect outside of the asyncore
@ -412,3 +415,21 @@ class NetworkThread(Thread):
[obj.handle_close() for obj in disconnected] [obj.handle_close() for obj in disconnected]
asyncore.loop(0.1, use_poll=True, map=mininode_socket_map, count=1) asyncore.loop(0.1, use_poll=True, map=mininode_socket_map, count=1)
logger.debug("Network thread closing") logger.debug("Network thread closing")
def network_thread_start():
"""Start the network thread."""
NetworkThread().start()
def network_thread_running():
"""Return whether the network thread is running."""
return any([thread.name == "NetworkThread" for thread in threading.enumerate()])
def network_thread_join(timeout=10):
"""Wait timeout seconds for the network thread to terminate.
Throw if the network thread doesn't terminate in timeout seconds."""
network_threads = [thread for thread in threading.enumerate() if thread.name == "NetworkThread"]
assert len(network_threads) <= 1
for thread in network_threads:
thread.join(timeout)
assert not thread.is_alive()