diff --git a/umnp/test.py b/misc/test.py similarity index 59% rename from umnp/test.py rename to misc/test.py index 0a347134c444bcfd2a8ae3eca9aeb10991f9a408..9bc74730ee3f2d541066a0df40e145026b7206e1 100644 --- a/umnp/test.py +++ b/misc/test.py @@ -1,16 +1,15 @@ import binascii - -from protocol.data_message import DataMessage from umnp.microcontroller.umock.machine import SPI, Pin from umnp.microcontroller.umock.network import WIZNET5K -from umnp.protocol import Message +from umnp.protocol.message import Message +from umnp.protocol.messagetype import MessageType -x = DataMessage("abasdc", None, None) -w = WIZNET5K(SPI(), Pin(1), Pin(1), mac=b'') -print(w.config('mac')) +x = Message(MessageType.MSG_DEVICE_DATA, "abasdc", None) +w = WIZNET5K(SPI(), Pin(1), Pin(1), mac=b"") +print(w.config("mac")) w.config(mac=3) -print(w.config('mac')) +print(w.config("mac")) encoded = x.encode() print("====") print(binascii.hexlify(encoded)) diff --git a/programs/main_test.py b/programs/main_test.py index 84af29af4a5f6487ad6fd984636901dab0042cf3..c5d2e5e642ea26bbf25ed49ff8d60b9ae924cc26 100644 --- a/programs/main_test.py +++ b/programs/main_test.py @@ -1,15 +1,16 @@ import sys -from umnp.microcontroller.communication.communicator import Communicator +from umnp.microcontroller.communication.udp_communicator import UDPCommunicator +from umnp.microcontroller.devices.network.ethernet_w5500 import EthernetW5500 +from umnp.microcontroller.devices.network.udp import UDPSender, UDPReceiver from umnp.microcontroller.measurementdevice import MeasurementDevice -from umnp.microcontroller.network.ethernet_w5500 import EthernetW5500 -from umnp.microcontroller.network.udp import UDPSender, UDPReceiver from umnp.microcontroller.sensors.sht25 import SHT25 from umnp.microcontroller.tasks.periodictask import PeriodicTask if sys.implementation.name == "micropython": # noinspection PyUnresolvedReferences import machine + # noinspection PyUnresolvedReferences import uasyncio as asyncio else: @@ -27,7 +28,9 @@ def test_function(*args): def main(): # configure network device = MeasurementDevice() - spi = machine.SPI(0, 2_000_000, mosi=machine.Pin(19), miso=machine.Pin(16), sck=machine.Pin(18)) + spi = machine.SPI( + 0, 2_000_000, mosi=machine.Pin(19), miso=machine.Pin(16), sck=machine.Pin(18) + ) ether = EthernetW5500(spi, 17, 20, mac=device.generated_mac_raw(), dhcp=True) device.add_network_adapter(ether) @@ -36,7 +39,9 @@ def main(): sender = UDPSender(ether.ip, ether.netmask, 7777) receiver = UDPReceiver(ether.ip, 7776) - comm = Communicator(receiver=receiver, sender=sender, device_id=device.identifier) + comm = UDPCommunicator( + receiver=receiver, sender=sender, device_id=device.identifier + ) x = PeriodicTask(test_function, print, 1000) comm.add_task(x, "test_function") diff --git a/programs/sensor_calibration.py b/programs/sensor_calibration.py new file mode 100644 index 0000000000000000000000000000000000000000..9fcbb0e0bbb8a655b09aa50a8d5b44659418b9cc --- /dev/null +++ b/programs/sensor_calibration.py @@ -0,0 +1,51 @@ +import sys + +from umnp.microcontroller.communication.udp_communicator import UDPCommunicator +from umnp.microcontroller.devices.network.ethernet_w5500 import EthernetW5500 +from umnp.microcontroller.devices.network.udp import UDPSender, UDPReceiver +from umnp.microcontroller.measurementdevice import MeasurementDevice +from umnp.microcontroller.sensors.lps28dfw import LPS28DFW +from umnp.microcontroller.sensors.sht45 import SHT45 +from umnp.microcontroller.tasks.periodictask import PeriodicTask + +if sys.implementation.name == "micropython": + # noinspection PyUnresolvedReferences + import machine + + # noinspection PyUnresolvedReferences + import uasyncio as asyncio +else: + from umnp.microcontroller.umock import machine + import asyncio + + +def test_function(*args): + print("test_function called with: ", *args) + result = " - ".join(*args) + print(f"test_function returns '{result}'") + return result + + +def main(): + # configure network + device = MeasurementDevice() + spi = machine.SPI( + 0, 2_000_000, mosi=machine.Pin(19), miso=machine.Pin(16), sck=machine.Pin(18) + ) + ether = EthernetW5500(spi, 17, 20, mac=device.generated_mac_raw(), dhcp=True) + device.add_network_adapter(ether) + + i2c = machine.I2C(id=1, scl=machine.Pin(27), sda=machine.Pin(26)) + sht45 = SHT45(i2c) + p_sensor = LPS28DFW(i2c) + + sender = UDPSender(ether.ip, ether.netmask, 7777) + receiver = UDPReceiver(ether.ip, 7776) + comm = UDPCommunicator( + receiver=receiver, sender=sender, device_id=device.identifier + ) + + x = PeriodicTask(test_function, print, 1000) + comm.add_task(x, "test_function") + # start + asyncio.run(comm.start()) diff --git a/umnp/microcontroller/cpc/__init__.py b/programs/test.py similarity index 100% rename from umnp/microcontroller/cpc/__init__.py rename to programs/test.py diff --git a/scripts/test-ae33.py b/scripts/test-ae33.py new file mode 100644 index 0000000000000000000000000000000000000000..ce647273c3941252327315cca17a30cd2e832bd2 --- /dev/null +++ b/scripts/test-ae33.py @@ -0,0 +1,34 @@ +import datetime +import logging +import os +import time + +from umnp.communication.serial_connection import SerialConnection +from umnp.devices.aethalometer.ae33 import AE33 + +now = datetime.datetime.now() +if now.microsecond >= 500 * 1000: + now = now + datetime.timedelta(seconds=1) +now = now.replace(microsecond=0) +now_string = now.isoformat().replace(":", "-") +logger = logging.getLogger("ae33-test") +logger.setLevel(logging.DEBUG) +log_file_fn = os.path.join("logs", f"ae33-{now_string}.log") +log_file = logging.FileHandler(log_file_fn) +log_file.setLevel(logging.DEBUG) +logger.addHandler(log_file) + + +def main(): + conn_opts = {"baudrate": 115200, "dsrdtr": True, "rtscts": False, "xonoff": False} + connection = SerialConnection("/dev/ttyUSB2", options=conn_opts) + ae33 = AE33(connection) + while True: + current = ae33.current_measurement() + if current: + logging.info(current) + time.sleep(1) + + +if __name__ == "__main__": + main() diff --git a/umnp/communication/__init__.py b/umnp/communication/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6f27cff42a0dd982c6ebd8a22b6c55486d7949e3 --- /dev/null +++ b/umnp/communication/__init__.py @@ -0,0 +1,6 @@ +SERIAL_DEFAULT_BAUD_RATE = 112512 +SERIAL_DEFAULT_WAIT_TIME_S = 0.01 + + +class ConnectionProblemException(Exception): + pass diff --git a/umnp/communication/abstract_serial_connection.py b/umnp/communication/abstract_serial_connection.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd55d138c348e8d98ea3dc70e08132b8d9cec28 --- /dev/null +++ b/umnp/communication/abstract_serial_connection.py @@ -0,0 +1,40 @@ +import asyncio + +from umnp.communication import SERIAL_DEFAULT_BAUD_RATE, SERIAL_DEFAULT_WAIT_TIME_S + + +class AbstractSerialConnection: + def __init__(self, address, options: dict | None = None): + self.__connection = None + self.__lock = asyncio.Lock() + self.__address = address + if options is None: + self.__options = {} + else: + self.__options = options + + self.__baud_rate = self.get_option("baudrate", SERIAL_DEFAULT_BAUD_RATE) + self.__wait_time = self.get_option("wait time", SERIAL_DEFAULT_WAIT_TIME_S) + self.__line_sep_read = self.get_option("read separator", b"\r") + self.__line_sep_write = self.get_option("write separator", b"\r") + self.__max_connection_attempts = self.get_option("max connection attempts", 3) + + def connect(self): + raise NotImplementedError + + def disconnect(self): + raise NotImplementedError + + def sync_command(self, command, expected_lines=None, show_reply=False): + raise NotImplementedError + + def send_command(self, command): + pass + + def read_line(self, timeout=None): + raise NotImplementedError + + def get_option(self, name: str, default=None): + if self.__options is None: + return default + return self.__options.get(name) diff --git a/umnp/communication/serial_connection.py b/umnp/communication/serial_connection.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1834d665d32775d18654e9082f09c62ebd4ed7 --- /dev/null +++ b/umnp/communication/serial_connection.py @@ -0,0 +1,52 @@ +import logging +import time + +import serial + +from umnp.communication import ConnectionProblemException +from umnp.communication.abstract_serial_connection import AbstractSerialConnection + + +class SerialConnection(AbstractSerialConnection): + def __init__(self, address, options=None): + super().__init__(address, options) + + def connect(self): + if self.__connection: + return + attempts = 0 + connection = None + max_attempts = self.__max_connection_attempts + while attempts < max_attempts: + try: + connection = serial.Serial( + port=self.__address, + baudrate=self.__baud_rate, + parity=self.get_option("parity", serial.PARITY_NONE), + stopbits=self.get_option("stopbits", serial.STOPBITS_ONE), + timeout=self.__wait_time, + xonxoff=self.get_option("xonoff", False), + rtscts=self.get_option("rtscts", False), + dsrdtr=self.get_option("dsrdtr", False), + ) + except serial.serialutil.SerialException as e: + raise ConnectionProblemException(e) + + if connection: + break + + attempts += 1 + time.sleep(0.1) + + if not connection: + logging.error(f"could not connect to {self.__address}") + return + + time.sleep(0.1) + + self.__connection = connection + connection.flush() + connection.reset_input_buffer() + connection.reset_output_buffer() + logging.info(f"connected to {self.__address}, baud rate: {self.__baud_rate}") + logging.debug(connection) diff --git a/umnp/microcontroller/display/__init__.py b/umnp/devices/__init__.py similarity index 100% rename from umnp/microcontroller/display/__init__.py rename to umnp/devices/__init__.py diff --git a/umnp/microcontroller/eeprom/__init__.py b/umnp/devices/aethalometer/__init__.py similarity index 100% rename from umnp/microcontroller/eeprom/__init__.py rename to umnp/devices/aethalometer/__init__.py diff --git a/umnp/devices/aethalometer/ae33.py b/umnp/devices/aethalometer/ae33.py new file mode 100644 index 0000000000000000000000000000000000000000..d027b44f2ddbbb24b16db1dfe00fa389d0611afe --- /dev/null +++ b/umnp/devices/aethalometer/ae33.py @@ -0,0 +1,23 @@ +from umnp.communication.abstract_serial_connection import AbstractSerialConnection + + +class SerialConnection: + def __init__(self): + pass + + +class AE33: + def __init__(self, connection: AbstractSerialConnection): + self.__last_measurement = None + self.__connection = connection + pass + + def request_measurement(self) -> str | None: + return self.__connection.sync_command("$A33:D1\r", 1) + + def current_measurement(self) -> str | None: + current = self.request_measurement() + if current == self.__last_measurement: + return None + self.__last_measurement = current + return current diff --git a/umnp/microcontroller/display/lcd.py b/umnp/devices/cpc/__init__.py similarity index 100% rename from umnp/microcontroller/display/lcd.py rename to umnp/devices/cpc/__init__.py diff --git a/umnp/test_enum.py b/umnp/microcontroller/communication/serial_uart.py similarity index 100% rename from umnp/test_enum.py rename to umnp/microcontroller/communication/serial_uart.py diff --git a/umnp/microcontroller/communication/communicator.py b/umnp/microcontroller/communication/udp_communicator.py similarity index 86% rename from umnp/microcontroller/communication/communicator.py rename to umnp/microcontroller/communication/udp_communicator.py index dfa41f9cad4c12efc3320062cee0196c50df10fc..c099ef72830112f60f0e7d77c7cdf46e3540d592 100644 --- a/umnp/microcontroller/communication/communicator.py +++ b/umnp/microcontroller/communication/udp_communicator.py @@ -1,12 +1,13 @@ import sys import time -from umnp.microcontroller.network.udp import UDPSender, UDPReceiver +from umnp.microcontroller.devices.network.udp import UDPSender, UDPReceiver from umnp.microcontroller.tasks.periodictask import PeriodicTask if sys.implementation.name == "micropython": # noinspection PyUnresolvedReferences import uasyncio as asyncio + # noinspection PyUnresolvedReferences import machine else: @@ -14,9 +15,10 @@ else: from umnp.microcontroller.umock import machine -class Communicator: - def __init__(self, sender: UDPSender, receiver: UDPReceiver, device_id, max_msgs: int = 10): - +class UDPCommunicator: + def __init__( + self, sender: UDPSender, receiver: UDPReceiver, device_id, max_msgs: int = 10 + ): self._receive_lock = asyncio.Lock() self._send_lock = asyncio.Lock() @@ -59,7 +61,14 @@ class Communicator: if msg is not None: msg = msg.replace(",", ";") now = rtc.datetime() - now = "%04d-%02d-%02dT%02d:%02d:%02d" % (now[0], now[1], now[2], now[4], now[5], now[6]) + now = "%04d-%02d-%02dT%02d:%02d:%02d" % ( + now[0], + now[1], + now[2], + now[4], + now[5], + now[6], + ) await self._sender.broadcast("%s,%s,%s" % (device_id, now, msg)) await asyncio.sleep(0.5) @@ -97,4 +106,3 @@ class Communicator: def add_task(self, task: PeriodicTask, name: str): self._tasks[name] = task - diff --git a/umnp/microcontroller/devices/__init__.py b/umnp/microcontroller/devices/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/umnp/microcontroller/devices/display/__init__.py b/umnp/microcontroller/devices/display/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/umnp/microcontroller/devices/display/lcd.py b/umnp/microcontroller/devices/display/lcd.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/umnp/microcontroller/display/picolcd13.py b/umnp/microcontroller/devices/display/picolcd13.py similarity index 100% rename from umnp/microcontroller/display/picolcd13.py rename to umnp/microcontroller/devices/display/picolcd13.py diff --git a/umnp/microcontroller/eeprom/EEPROM_24LC32A.py b/umnp/microcontroller/devices/eeprom/EEPROM_24LC32A.py similarity index 100% rename from umnp/microcontroller/eeprom/EEPROM_24LC32A.py rename to umnp/microcontroller/devices/eeprom/EEPROM_24LC32A.py diff --git a/umnp/microcontroller/devices/eeprom/__init__.py b/umnp/microcontroller/devices/eeprom/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/umnp/microcontroller/network/__init__.py b/umnp/microcontroller/devices/network/__init__.py similarity index 100% rename from umnp/microcontroller/network/__init__.py rename to umnp/microcontroller/devices/network/__init__.py diff --git a/umnp/microcontroller/network/ethernet.py b/umnp/microcontroller/devices/network/ethernet.py similarity index 100% rename from umnp/microcontroller/network/ethernet.py rename to umnp/microcontroller/devices/network/ethernet.py diff --git a/umnp/microcontroller/network/ethernet_w5500.py b/umnp/microcontroller/devices/network/ethernet_w5500.py similarity index 89% rename from umnp/microcontroller/network/ethernet_w5500.py rename to umnp/microcontroller/devices/network/ethernet_w5500.py index 0b4c9691812717ab753388b915b93395e353fff6..43425a25d810b0c81a79f16ddb6d4b11c3f1a003 100644 --- a/umnp/microcontroller/network/ethernet_w5500.py +++ b/umnp/microcontroller/devices/network/ethernet_w5500.py @@ -1,9 +1,11 @@ -from umnp.microcontroller.network.ethernet import EthernetAdapter import sys +from umnp.microcontroller.devices.network.ethernet import EthernetAdapter + if sys.implementation.name == "micropython": # noinspection PyUnresolvedReferences import machine + # noinspection PyUnresolvedReferences import network else: @@ -16,7 +18,7 @@ class EthernetW5500(EthernetAdapter): self._spi = spi self._pin1 = machine.Pin(pin1) self._pin2 = machine.Pin(pin2) - self._nic = network.WIZNET5K(spi, self._pin1, self._pin2) + self._nic = network.WIZNET5K(spi, self._pin1, self._pin2, mac) # self._nic.active(True) self._nic.config(mac=mac) if dhcp: @@ -26,7 +28,7 @@ class EthernetW5500(EthernetAdapter): while True: print("Requesting IP via DHCP") try: - self._nic.ifconfig('dhcp') + self._nic.ifconfig("dhcp") return except OSError: pass diff --git a/umnp/microcontroller/network/udp.py b/umnp/microcontroller/devices/network/udp.py similarity index 85% rename from umnp/microcontroller/network/udp.py rename to umnp/microcontroller/devices/network/udp.py index 4f3c098844f8370e9f6332fb642e6ddf07648273..0b06fea160a828ffb5f7977d006ecb96da24996e 100644 --- a/umnp/microcontroller/network/udp.py +++ b/umnp/microcontroller/devices/network/udp.py @@ -1,8 +1,11 @@ -import sys -import socket import select +import socket +import sys -from umnp.microcontroller.network import LISTEN_TIMEOUT_MS, calculate_broadcast_ip +from umnp.microcontroller.devices.network import ( + LISTEN_TIMEOUT_MS, + calculate_broadcast_ip, +) if sys.implementation.name == "micropython": # noinspection PyUnresolvedReferences @@ -12,7 +15,9 @@ else: class UDPReceiver: - def __init__(self, listen_ip: str, listen_port: int, timeout: int = LISTEN_TIMEOUT_MS): + def __init__( + self, listen_ip: str, listen_port: int, timeout: int = LISTEN_TIMEOUT_MS + ): self.socket = None self._listen_port = listen_port self.listen_ip = listen_ip @@ -46,6 +51,6 @@ class UDPSender: async def broadcast(self, msg): # print("sending %s" % msg) if isinstance(msg, str): - msg = msg.encode('utf-8') + msg = msg.encode("utf-8") if self.socket: self.socket.sendto(msg, (self.broadcast_ip, self._target_port)) diff --git a/umnp/microcontroller/sensors/sht25/__init__.py b/umnp/microcontroller/sensors/sht25/__init__.py index c2304939f58542643c32b25f1b2246e0978ac5c8..88a9fbbea1832581cd4cdc07a50a6463fc7e5deb 100644 --- a/umnp/microcontroller/sensors/sht25/__init__.py +++ b/umnp/microcontroller/sensors/sht25/__init__.py @@ -3,8 +3,8 @@ try: except ImportError: from umnp.microcontroller.umock.machine import I2C -import time import asyncio +import time # SHT25_READ_T_HOLD = # SHT25_READ_RH_HOLD = @@ -44,7 +44,8 @@ class SHT25: def _command(self, command: int) -> int: pass - def _crc8(self, data): + @staticmethod + def _crc8(data): # CRC-8-Dallas/Maxim for I2C with 0x31 polynomial crc = 0x0 for byte in data: diff --git a/umnp/microcontroller/tasks/periodictask.py b/umnp/microcontroller/tasks/periodictask.py index deb9c32ed2377be7b5ee3df5eb2edf6c0627a610..ce2801fde2fbb40480a770a38395866991e1bff3 100644 --- a/umnp/microcontroller/tasks/periodictask.py +++ b/umnp/microcontroller/tasks/periodictask.py @@ -1,5 +1,6 @@ -import time import sys +import time + if sys.implementation.name == "micropython": # noinspection PyUnresolvedReferences import uasyncio as asyncio @@ -8,7 +9,9 @@ else: class PeriodicTask: - def __init__(self, function: callable, async_call_back: callable, every_ms: int, *args): + def __init__( + self, function: callable, async_call_back: callable, every_ms: int, *args + ): self._function = function self._args = args self._every_ms = every_ms @@ -18,14 +21,15 @@ class PeriodicTask: func = self._function args = self._args call_back = self._call_back - delay_seconds = 25/1000 + delay_seconds = 25 / 1000 while True: last = time.time() print("A0", args) print("A1", *args) result = await func(*args) print(result) - await call_back(result) + if call_back: + await call_back(result) while True: now = time.time() diff --git a/umnp/microcontroller/umock/framebuf/__init__.py b/umnp/microcontroller/umock/framebuf/__init__.py index fcea3fe449d5e1f3f5e14d56d3d5420203842e0c..a6219804f713105b4ea64dcfb4246ca66eaa6e0e 100644 --- a/umnp/microcontroller/umock/framebuf/__init__.py +++ b/umnp/microcontroller/umock/framebuf/__init__.py @@ -9,9 +9,8 @@ class FrameBuffer: self._format = fmt self._stride = stride if self._stride is None: - self._stride = width # or height? FIXME + self._stride = width # or height? FIXME self._fill = None def fill(self, *args): self._fill = args - diff --git a/umnp/microcontroller/umock/machine/__init__.py b/umnp/microcontroller/umock/machine/__init__.py index a8576b0fbf3538dac4d185b45f1669c07f5fbc9d..5bda3828b3e4f6bf7f12d9d65e8a8eec73bf73a4 100644 --- a/umnp/microcontroller/umock/machine/__init__.py +++ b/umnp/microcontroller/umock/machine/__init__.py @@ -1,7 +1,7 @@ import datetime import uuid -from umnp.protocol import MSG_BYTE_ORDER +from umnp.protocol.message_header import MSG_BYTE_ORDER def unique_id() -> bytes: diff --git a/umnp/protocol/__init__.py b/umnp/protocol/__init__.py index eb786fa7fbcd5e2ae293a686c11484c91166a999..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/umnp/protocol/__init__.py +++ b/umnp/protocol/__init__.py @@ -1,7 +0,0 @@ -from umnp.protocol.constants import * -from umnp.protocol.messagetype import MSG_DEVICE_DATA - -from umnp.protocol.message import Message -from umnp.protocol.data_message import DataMessage - -Message.add_message_type(MSG_DEVICE_DATA, DataMessage) diff --git a/umnp/protocol/constants.py b/umnp/protocol/constants.py index e8ffbf06c7b6a9c1d9f600a897d0a049c10431b9..446b7ea017b67f35c423dbd25583a46f5beac793 100644 --- a/umnp/protocol/constants.py +++ b/umnp/protocol/constants.py @@ -4,12 +4,27 @@ try: except ImportError: pass -MSG_STRING_ENCODING: str = 'utf-8' +MSG_STRING_ENCODING: str = "utf-8" MSG_PROTOCOL_VERSION: int = 0 -MSG_BYTE_ORDER: typing.Literal["little", "big"] = 'big' -MSG_VERSION_LENGTH: int = 2 -MSG_TYPE_LENGTH: int = 4 -MSG_PAYLOAD_LENGTH: int = 4 -MSG_TIMESTAMP_LENGTH: int = 4 -MSG_SENDER_ID_LENGTH: int = 6 -MSG_SENDER_TYPE_LENGTH: int = 2 + +MSG_BYTE_ORDER: typing.Literal["little", "big"] = "big" + +# number of bytes reserved for fields in the message header + +# protocol version +MSG_LEN_PROTOCOL_VERSION: int = 2 + +# message type, see unmp.protocol.messagetype +MSG_LEN_TYPE: int = 2 + +# message timestamp (of sending), seconds since epoch for version 0 +MSG_LEN_TIMESTAMP: int = 4 + +# sender unique id (e.g. serial number) +MSG_LEN_SENDER_ID: int = 6 + +# sender type (e.g. model of measurement or control device attached to microcontroller) +MSG_LEN_SENDER_TYPE: int = 2 + +# length of payload +MSG_LEN_PAYLOAD_SIZE: int = 2 diff --git a/umnp/protocol/data_message.py b/umnp/protocol/data_message.py index e1f77617b3f29aa5d180bcecfe8690f71beb1925..b59994a116af0c9acba8a09d5fd706d55b03eb89 100644 --- a/umnp/protocol/data_message.py +++ b/umnp/protocol/data_message.py @@ -1,14 +1,22 @@ -from umnp.protocol import MSG_DEVICE_DATA +from umnp.protocol.common.timestamp import TimeStamp from umnp.protocol.constants import MSG_STRING_ENCODING from umnp.protocol.message import Message, MessageHeader -from umnp.protocol.common.timestamp import TimeStamp +from umnp.protocol.messagetype import MessageType class DataMessage(Message): - def __init__(self, data: str, sender_id: bytes, sender_type: int, send_time: TimeStamp = None): + def __init__( + self, data: str, sender_id: bytes, sender_type: int, send_time: TimeStamp = None + ): self._payload = data self._encoded_data = data.encode(MSG_STRING_ENCODING) - super().__init__(MSG_DEVICE_DATA, self._encoded_data, sender_id, sender_type, send_time) + super().__init__( + MessageType.MSG_DEVICE_DATA, + self._encoded_data, + sender_id, + sender_type, + send_time, + ) @staticmethod def _decode_payload(transferred_data): @@ -18,9 +26,8 @@ class DataMessage(Message): return self._payload @classmethod - def decode(cls, payload: bytes, header: MessageHeader) -> 'DataMessage': + def decode(cls, payload: bytes, header: MessageHeader) -> "DataMessage": decoded_payload = cls._decode_payload(payload) - return cls(decoded_payload, header.sender_id, header.sender_type, header.timestamp) - - - + return cls( + decoded_payload, header.sender_id, header.sender_type, header.timestamp + ) diff --git a/umnp/protocol/message.py b/umnp/protocol/message.py index 0417874e08d6f8a4c6db8ff43f98c7440498322e..b3d5822958ed131fa20cafa4712ff120ebed36d0 100644 --- a/umnp/protocol/message.py +++ b/umnp/protocol/message.py @@ -1,4 +1,3 @@ - try: # noinspection PyUnresolvedReferences import typing @@ -11,9 +10,9 @@ except ImportError: import umnp.protocol.compat.logging as logging -from umnp.protocol.constants import MSG_BYTE_ORDER, MSG_PAYLOAD_LENGTH -from umnp.protocol.message_header import MessageHeader from umnp.protocol.common.timestamp import TimeStamp +from umnp.protocol.constants import MSG_BYTE_ORDER, MSG_LEN_PAYLOAD_SIZE +from umnp.protocol.message_header import MessageHeader class Message: @@ -23,12 +22,14 @@ class Message: def message_types(self): return self._registered_types - def __init__(self, - msg_type: int, - data: bytes, - sender_id: bytes, - sender_type: int, - send_time: typing.Optional[TimeStamp]): + def __init__( + self, + msg_type: int, + data: bytes, + sender_id: bytes, + sender_type: int, + send_time: typing.Optional[TimeStamp], + ): """ Parameters ---------- @@ -38,14 +39,15 @@ class Message: self._sender_id = sender_id self._sender_type = sender_type self._data = data - self._header = MessageHeader(msg_type, - sender_device_type=sender_type, - sender_device_id=sender_id, - send_time=send_time) - + self._header = MessageHeader( + msg_type, + sender_device_type=sender_type, + sender_device_id=sender_id, + send_time=send_time, + ) @classmethod - def add_message_type(cls, msg_type: int, msg: typing.Type['Message']): + def add_message_type(cls, msg_type: int, msg: typing.Type["Message"]): if msg_type in cls._registered_types: logging.info(f"Already registered {msg_type}") return @@ -68,11 +70,15 @@ class Message: return self._data def encode(self) -> bytes: - header = MessageHeader(msg_type=self.type, - sender_device_type=self._sender_type, - sender_device_id=self._sender_id) + header = MessageHeader( + msg_type=self.type, + sender_device_type=self._sender_type, + sender_device_id=self._sender_id, + ) payload = self.data - payload_length = len(payload).to_bytes(length=MSG_PAYLOAD_LENGTH, byteorder=MSG_BYTE_ORDER) + payload_length = len(payload).to_bytes( + length=MSG_LEN_PAYLOAD_SIZE, byteorder=MSG_BYTE_ORDER + ) return header.encode() + payload_length + payload @classmethod @@ -80,23 +86,24 @@ class Message: header = MessageHeader.decode(data) offset = header.encoded_size_bytes try: - payload_bytes = data[offset:offset + MSG_PAYLOAD_LENGTH] + payload_bytes = data[offset : offset + MSG_LEN_PAYLOAD_SIZE] payload_length = int.from_bytes(payload_bytes, MSG_BYTE_ORDER) except (KeyError, ValueError): logging.error("Invalid message: could not extract payload length") return None - offset += MSG_PAYLOAD_LENGTH + offset += MSG_LEN_PAYLOAD_SIZE payload = data[offset:] if len(payload) != payload_length: - logging.error("Invalid message: mismatch between specified and actual payload length") + logging.error( + "Invalid message: mismatch between specified and actual payload length" + ) return None - msg = Message.get_message_type(header.message_type).decode(payload, header=header) + msg = Message.get_message_type(header.message_type).decode( + payload, header=header + ) return msg - - - def __str__(self): return f"Message of type {self.type}" diff --git a/umnp/protocol/message_header.py b/umnp/protocol/message_header.py index c0de102534408c1c737f6f148d558f0200366985..2bed493c2e10f2f0c89214cbfb0c4d7f87f54806 100644 --- a/umnp/protocol/message_header.py +++ b/umnp/protocol/message_header.py @@ -1,6 +1,13 @@ from umnp.protocol.common.timestamp import TimeStamp, valid_timestamp -from umnp.protocol.constants import MSG_PROTOCOL_VERSION, MSG_VERSION_LENGTH, MSG_BYTE_ORDER, MSG_TYPE_LENGTH, \ - MSG_TIMESTAMP_LENGTH, MSG_SENDER_ID_LENGTH, MSG_SENDER_TYPE_LENGTH +from umnp.protocol.constants import ( + MSG_PROTOCOL_VERSION, + MSG_LEN_PROTOCOL_VERSION, + MSG_BYTE_ORDER, + MSG_LEN_TYPE, + MSG_LEN_TIMESTAMP, + MSG_LEN_SENDER_ID, + MSG_LEN_SENDER_TYPE, +) try: import logging @@ -15,53 +22,78 @@ except ImportError: class MessageHeader: - def __init__(self, - msg_type: int, - sender_device_id: bytes, - sender_device_type: int, - send_time: typing.Optional[TimeStamp] = None, - version: int = MSG_PROTOCOL_VERSION, - ): - + def __init__( + self, + msg_type: int, + sender_device_id: bytes, + sender_device_type: int, + send_time: typing.Optional[TimeStamp] = None, + version: int = MSG_PROTOCOL_VERSION, + ): self._message_type = msg_type self._send_time = valid_timestamp(send_time) self._timestamp = TimeStamp(self._send_time) self._version = version self._sender_device_id = sender_device_id - self._sender_device_type = sender_device_type + self._sender_dev_type = sender_device_type + + @property + def timestamp_encoded(self) -> bytes: + return self._timestamp.value.to_bytes(MSG_LEN_TIMESTAMP, MSG_BYTE_ORDER) + + @property + def sender_dev_id(self) -> bytes: + return self._sender_device_id @property def message_type(self) -> int: return self._message_type @property - def version(self): + def message_type_encoded(self) -> bytes: + return self.message_type.to_bytes(MSG_LEN_TYPE, MSG_BYTE_ORDER) + + @property + def version(self) -> int: return self._version + @property + def version_encoded(self) -> bytes: + return self._version.to_bytes(MSG_LEN_PROTOCOL_VERSION, MSG_BYTE_ORDER) + @property def sender_id(self) -> bytes: return self._sender_device_id @property def sender_type(self) -> int: - return self._sender_device_type + return self._sender_dev_type + + @property + def sender_dev_type_encoded(self) -> bytes: + return self._sender_dev_type.to_bytes(MSG_LEN_SENDER_TYPE, MSG_BYTE_ORDER) @property def timestamp(self) -> TimeStamp: return self._timestamp def encode(self) -> bytes: - version = self.version.to_bytes(length=MSG_VERSION_LENGTH, byteorder=MSG_BYTE_ORDER) + version = self.version_encoded sender_id = self._sender_device_id - sender_type = self._sender_device_type.to_bytes(length=MSG_SENDER_TYPE_LENGTH, byteorder=MSG_BYTE_ORDER) - message_type = self.message_type.to_bytes(length=MSG_TYPE_LENGTH, byteorder=MSG_BYTE_ORDER) - timestamp = self._timestamp.value.to_bytes(length=MSG_TIMESTAMP_LENGTH, byteorder=MSG_BYTE_ORDER) + sender_type = self.sender_dev_type_encoded + message_type = self.message_type_encoded + timestamp = self.timestamp_encoded return version + sender_id + sender_type + message_type + timestamp @property def encoded_size_bytes(self) -> int: - return (MSG_VERSION_LENGTH + MSG_SENDER_ID_LENGTH + MSG_SENDER_TYPE_LENGTH + - MSG_TYPE_LENGTH + MSG_TIMESTAMP_LENGTH) + return ( + MSG_LEN_PROTOCOL_VERSION + + MSG_LEN_SENDER_ID + + MSG_LEN_SENDER_TYPE + + MSG_LEN_TYPE + + MSG_LEN_TIMESTAMP + ) @classmethod def decode(cls, data: bytes): @@ -71,48 +103,50 @@ class MessageHeader: logging.error("Invalid message header: not bytes") return None try: - protocol_bytes = data[:MSG_VERSION_LENGTH] + protocol_bytes = data[:MSG_LEN_PROTOCOL_VERSION] protocol_version = int.from_bytes(protocol_bytes, MSG_BYTE_ORDER) cls._version = protocol_version except (KeyError, ValueError): logging.error("Invalid message header: could not extract version") return None - offset += MSG_VERSION_LENGTH + offset += MSG_LEN_PROTOCOL_VERSION try: - message_sender_id_bytes = data[offset:offset + MSG_SENDER_ID_LENGTH] - message_sender_type_bytes = data[ - offset + MSG_SENDER_ID_LENGTH: offset + MSG_SENDER_ID_LENGTH + MSG_SENDER_TYPE_LENGTH] - message_sender_type = int.from_bytes(message_sender_type_bytes, MSG_BYTE_ORDER) - except(KeyError, ValueError): + msg_sender_id_b = data[offset : (offset + MSG_LEN_SENDER_ID)] + offset += MSG_LEN_SENDER_ID + msg_sender_type_b = data[offset : offset + MSG_LEN_SENDER_TYPE] + offset += MSG_LEN_SENDER_TYPE + msg_sender_type = int.from_bytes(msg_sender_type_b, MSG_BYTE_ORDER) + except (KeyError, ValueError): logging.error("Invalid message sender information: could not extract data") return None - offset += MSG_SENDER_TYPE_LENGTH + MSG_SENDER_ID_LENGTH - try: - message_type_bytes = data[offset:offset + MSG_TYPE_LENGTH] + message_type_bytes = data[offset : offset + MSG_LEN_TYPE] message_type = int.from_bytes(message_type_bytes, MSG_BYTE_ORDER) except (KeyError, ValueError): logging.error("Invalid message payload: could not extract version") return None - offset += MSG_TYPE_LENGTH + offset += MSG_LEN_TYPE if protocol_version < 0 or protocol_version > MSG_PROTOCOL_VERSION: - logging.error(f"Invalid protocol version {protocol_version}, outside of range [0, {MSG_PROTOCOL_VERSION}]") + err = f"Invalid protocol version {protocol_version}, outside of range [0, {MSG_PROTOCOL_VERSION}]" + logging.error(err) return None try: - message_ts = data[offset: offset + MSG_TIMESTAMP_LENGTH] + message_ts = data[offset : offset + MSG_LEN_TIMESTAMP] timestamp = TimeStamp(int.from_bytes(message_ts, byteorder=MSG_BYTE_ORDER)) except (KeyError, ValueError): logging.error("Invalid message timestamp: could not extract version") return None - return cls(message_type, - sender_device_type=message_sender_type, - sender_device_id=message_sender_id_bytes, - version=protocol_version, - send_time=timestamp) + return cls( + message_type, + sender_device_type=msg_sender_type, + sender_device_id=msg_sender_id_b, + version=protocol_version, + send_time=timestamp, + ) diff --git a/umnp/protocol/messagetype.py b/umnp/protocol/messagetype.py index a27bcbc19dbbf06425a10ece88c673f0d8858bb4..bd4b1caf9e8b4b9ac6f673de347cc456a8dea929 100644 --- a/umnp/protocol/messagetype.py +++ b/umnp/protocol/messagetype.py @@ -1,3 +1,12 @@ # This would be an enum, if micropython supported them -MSG_DEVICE_DATA = 1 + +class MessageType: + MSG_DEVICE_DATA = 1 + MSG_TYPE_UNKNOWN = 65535 + + _allowed_message_types = [MSG_DEVICE_DATA, MSG_TYPE_UNKNOWN] + + @staticmethod + def valid_message_type(msg_type: int): + return msg_type in MessageType._allowed_message_types