# Source code for sleap.message

"""
Module with classes for sending and receiving messages between processes.

These use ZMQ pub/sub sockets.

Most of the time you'll want the PairedSender and PairedReceiver.
These support a "handshake" to confirm the connection. Without an initial
handshake, there's a good chance that early messages will be dropped.

Each message is either a dictionary or a dictionary plus a numpy ndarray.
"""
import attr
import jsonpickle
import numpy as np
import time
import zmq

from typing import Any, Callable, List, Optional, Text


@attr.s(auto_attribs=True)
class BaseMessageParticipant:
    """Base class for simple Sender and Receiver.

    Holds the ZMQ address, an optional ``zmq.Context`` and the socket that a
    subclass creates in its ``setup()``. If no context is supplied, one is
    created here and terminated when this object is garbage-collected.
    """

    address: Text = "tcp://127.0.0.1:9001"
    context: Optional[zmq.Context] = None
    _socket: Optional[zmq.Socket] = None

    def __attrs_post_init__(self):
        # Remember whether we created the context: only contexts we own are
        # terminated in __del__, since a caller-supplied one may be shared.
        if self.context is None:
            self._owns_context = True
            self.context = zmq.Context()
        else:
            self._owns_context = False

    def __del__(self):
        # getattr: __del__ also runs on instances whose __init__ failed before
        # __attrs_post_init__ could set _owns_context.
        if getattr(self, "_owns_context", False) and self.context is not None:
            self.context.term()
@attr.s(auto_attribs=True)
class Receiver(BaseMessageParticipant):
    """Receives messages from corresponding Sender.

    ``setup()`` binds a SUB socket, subscribed to every topic, at ``address``;
    the matching Sender connects its PUB socket to that address.
    """

    _message_queue: List[Any] = attr.ib(factory=list)

    def setup(self):
        self._socket = self.context.socket(zmq.SUB)
        self._socket.subscribe("")
        self._socket.bind(self.address)

    def __del__(self):
        # getattr: the attribute may be missing if __init__ failed.
        socket = getattr(self, "_socket", None)
        if socket is not None:
            try:
                socket.unbind(socket.LAST_ENDPOINT)
            except zmq.ZMQError:
                # Best-effort: if bind() never succeeded (e.g. the port was
                # already in use) there is nothing to unbind; still close.
                pass
            # Discard pending messages so that terminating the context in the
            # base class cannot block on this socket.
            socket.setsockopt(zmq.LINGER, 0)
            socket.close()
            self._socket = None
        # Terminate the context if this receiver created it.
        super().__del__()

    def push_back_message(self, message):
        """Act like we didn't receive this message yet.

        The message is appended to an internal queue that ``check_message``
        drains (oldest first) before reading from the socket.
        """
        self._message_queue.append(message)

    def _recv(self, flags=0, copy=True, track=False):
        """Read one JSON message and, if the sender attached one, its ndarray.

        ``Sender.send_array`` sends a JSON header containing "dtype" and
        "shape" followed by a second frame holding the raw array bytes. When
        such a frame is present, it is decoded and stored under "ndarray".
        """
        json_message = self._socket.recv_json(flags=flags)
        # Only read a second frame if one is actually part of this message:
        # a plain dict that happens to contain "dtype"/"shape" keys must not
        # make us block waiting for an array that was never sent.
        if (
            "dtype" in json_message
            and "shape" in json_message
            and self._socket.getsockopt(zmq.RCVMORE)
        ):
            msg = self._socket.recv(flags=flags, copy=copy, track=track)
            buf = memoryview(msg)
            A = np.frombuffer(buf, dtype=json_message["dtype"]).reshape(
                json_message["shape"]
            )
            json_message["ndarray"] = A
        return json_message

    def check_message(self, timeout: int = 10, fresh: bool = False) -> Any:
        """Attempt to receive a single message.

        Args:
            timeout: Timeout passed to ``zmq.Socket.poll`` (milliseconds in
                pyzmq).
            fresh: If True, ignore messages queued via ``push_back_message``
                and read from the socket instead.

        Returns:
            The next message, or None if none arrived within ``timeout``.
        """
        if self._message_queue and not fresh:
            return self._message_queue.pop(0)
        if self._socket is None:
            self.setup()
        if self._socket and self._socket.poll(timeout, zmq.POLLIN):
            return self._recv()
        return None

    def check_messages(self, timeout: int = 10, times_to_check: int = 10) -> List[dict]:
        """
        Attempt to receive multiple messages.

        This method allows us to keep up with the messages by getting multiple
        messages that have been sent since the last check. It performs at most
        ``times_to_check`` checks and stops early as soon as a check returns
        no message.

        Args:
            timeout: Per-check timeout passed to ``check_message``.
            times_to_check: Maximum number of checks (and so of messages).

        Returns:
            The messages received, in arrival order (possibly empty).
        """
        messages = []
        for _ in range(times_to_check):
            this_message = self.check_message(timeout)
            # No message within the timeout: nothing more is waiting.
            if this_message is None:
                break
            messages.append(this_message)
        return messages
@attr.s(auto_attribs=True)
class Sender(BaseMessageParticipant):
    """Publishes messages to corresponding Receiver.

    ``setup()`` (called lazily by the send methods) creates a PUB socket and
    connects it to ``address``, where the Receiver is expected to be bound.
    """

    def setup(self):
        self._socket = self.context.socket(zmq.PUB)
        self._socket.connect(self.address)

    def __del__(self):
        # The socket is only created by setup(); it may still be None (or,
        # if __init__ failed, missing entirely).
        socket = getattr(self, "_socket", None)
        if socket is not None:
            # Drop unsent messages so terminating the context can't block.
            socket.setsockopt(zmq.LINGER, 0)
            socket.close()
            self._socket = None
        super().__del__()

    def send_dict(self, data: dict):
        """Sends dictionary (JSON-encoded, single frame)."""
        if self._socket is None:
            self.setup()
        self._socket.send_json(data)

    def send_array(
        self, header_data: dict, A: np.ndarray, flags=0, copy=True, track=False
    ):
        """Sends dictionary + numpy ndarray.

        The header is sent as JSON with "dtype" and "shape" added, followed by
        a second frame containing the array's raw bytes in C order (which is
        how ``Receiver`` reshapes them).

        Args:
            header_data: Dictionary to send with the array. It is not modified.
            A: Array to send.
            flags, copy, track: Passed through to ``zmq.Socket.send``.

        Returns:
            Whatever ``zmq.Socket.send`` returns for the array frame.
        """
        if self._socket is None:
            self.setup()
        # The receiver rebuilds the array with a C-order reshape, so the raw
        # bytes must be C-contiguous (transposed/sliced arrays are not).
        if not A.flags.c_contiguous:
            A = A.copy(order="C")
        # Copy so the caller's dict is not polluted with dtype/shape keys,
        # which would make a later send_dict of it look like an array header.
        header = dict(header_data)
        header["dtype"] = str(A.dtype)
        header["shape"] = A.shape
        self._socket.send_json(header, flags | zmq.SNDMORE)
        return self._socket.send(A, flags, copy=copy, track=track)
@attr.s(auto_attribs=True)
class PairedMessageParticipant:
    """Base for a participant owning one Sender and one Receiver.

    Messages go out on ``sender_address`` and come in on
    ``receiver_address``. Call ``setup()`` before sending or receiving.
    """

    sender_address: Text
    receiver_address: Text
    context: Optional[zmq.Context] = None

    @classmethod
    def from_tcp_ports(cls, send_port, rec_port):
        """Build a participant using localhost TCP ports."""
        sender_address = f"tcp://127.0.0.1:{send_port}"
        receiver_address = f"tcp://127.0.0.1:{rec_port}"
        return cls(sender_address=sender_address, receiver_address=receiver_address)

    def setup(self):
        self._sender = Sender(address=self.sender_address, context=self.context)
        self._receiver = Receiver(address=self.receiver_address, context=self.context)
        self._sender.setup()
        self._receiver.setup()

    def close(self):
        # Dropping our references lets the Sender/Receiver destructors close
        # their sockets (immediately under CPython's reference counting, as
        # long as nothing else holds them).
        if hasattr(self, "_sender"):
            del self._sender
        if hasattr(self, "_receiver"):
            del self._receiver


@attr.s(auto_attribs=True)
class PairedSender(PairedMessageParticipant):
    """Sending side of a pair; confirms the connection with a handshake."""

    connected: bool = False

    @classmethod
    def from_defaults(cls):
        # Sends on 9001, listens for replies on 9002 (mirror of PairedReceiver).
        return cls.from_tcp_ports(9001, 9002)

    def send_handshake(self, timeout_sec=30):
        """Send handshake until we get reply.

        Returns:
            True (and sets ``connected``) once a handshake reply arrives,
            False if none arrives within ``timeout_sec`` seconds.
        """
        # monotonic: unaffected by wall-clock adjustments during the wait.
        wait_till = time.monotonic() + timeout_sec
        while time.monotonic() < wait_till:
            self._sender.send_dict(dict(type="handshake request"))
            reply = self._receiver.check_message()
            if self._is_handshake_reply(reply):
                self.connected = True
                return True
            # currently we drop replies until handshake is acknowledged
            time.sleep(0.1)
        return False

    def _is_handshake_reply(self, message: Any) -> bool:
        if message:
            return message.get("type", "") == "handshake reply"
        return False

    def send_dict(self, *args, **kwargs):
        self._sender.send_dict(*args, **kwargs)

    def send_array(self, *args, **kwargs):
        self._sender.send_array(*args, **kwargs)


@attr.s(auto_attribs=True)
class PairedReceiver(PairedMessageParticipant):
    """Receiving side of a pair; acknowledges handshake requests."""

    connected: bool = False

    @classmethod
    def from_defaults(cls):
        # Sends replies on 9002, listens on 9001 (mirror of PairedSender).
        return cls.from_tcp_ports(9002, 9001)

    def receive_handshake(self, timeout_sec=30):
        """Waits to receive and acknowledge handshake message.

        Returns:
            True if already connected, if a handshake request was received and
            acknowledged, or if some other message arrived first (it is pushed
            back so a later check still returns it). False on timeout.
        """
        if self.connected:
            return True
        wait_till = time.monotonic() + timeout_sec
        while time.monotonic() < wait_till:
            message = self._receiver.check_message(fresh=True)
            if message is None:
                continue
            if self._is_handshake(message):
                self._respond_to_handshake()
                return True
            # A regular message means the sender is already talking to us;
            # keep it for the next check_message/check_messages call.
            self._receiver.push_back_message(message)
            return True
        return False

    def _respond_to_handshake(self):
        self._sender.send_dict(dict(type="handshake reply"))
        self.connected = True

    def _is_handshake(self, message: Any):
        if message:
            return message.get("type", "") == "handshake request"
        return False

    def check_messages(self, ack_handshakes: bool = True, *args, **kwargs):
        """
        Checks for messages.

        Args:
            ack_handshakes: If True, then any handshake messages are
                acknowledged (one reply is sent) and aren't included in
                return results.
            *args, **kwargs: Passed to ``Receiver.check_messages``.

        Returns:
            List of messages, possibly excluding any handshake requests.
        """
        messages = self._receiver.check_messages(*args, **kwargs)
        if ack_handshakes:
            non_handshakes = [m for m in messages if not self._is_handshake(m)]
            if len(non_handshakes) < len(messages):
                self._respond_to_handshake()
            messages = non_handshakes
        return messages