# -*- coding: utf-8 -*- """ ------------------------------------------------------------------------------ Mango 802.11 Reference Design Experiments Framework - Messages ------------------------------------------------------------------------------ License: Copyright 2019 Mango Communications, Inc. All rights reserved. Use and distribution subject to terms in LICENSE.txt ------------------------------------------------------------------------------ This module provides class definitions for all over-the-wire messages. Functions (see below for more information): TransportHeader() -- Transport Header to describe message contents Cmd() -- Base class for commands the require responses BufferCmd() -- Base class for commands that require buffer responses Resp() -- Base class for responses (single packet) Buffer() -- Base class for responses (multiple packets) """ import struct from . import transport __all__ = ['TransportHeader', 'Cmd', 'BufferCmd', 'Resp', 'Buffer'] # Buffer Command defines CMD_BUFFER_GET_SIZE_FROM_DATA = 0xFFFFFFFF class Message(object): """Base class for messages to / from 802.11 ref design nodes. Attributes: const -- Dictionary of constants """ const = None def __init__(self): self.const = dict() def add_const(self, name, value): """Add a constant to the Message.""" if (name in self.const.keys()): print("WARNING: Updating value of constant: {0}".format(name)) self.const[name] = value def get_const(self, name): """Get a constant from the Message.""" if (name in self.const.keys()): return self.const[name] else: msg = "Constant {0} does not exist ".format(name) msg += "in {0}".format(self.__class__.__name__) raise AttributeError(msg) def serialize(self,): raise NotImplementedError def deserialize(self, data_buffer): raise NotImplementedError def sizeof(self,): raise NotImplementedError # End Class class TransportHeader(Message): """Class for Transport Header describing message contents. Attributes: dest_id -- (uint16) Destination ID of the message src_id -- (uint16) Source ID of the message length -- (uint16) Length of the payload in bytes seq_num -- (uint16) Sequence number of the message flags -- (uint16) Flags of the message """ def __init__(self, dest_id=0, src_id=0, reserved=0, length=0, seq_num=0, flags=0): super(TransportHeader, self).__init__() self.dest_id = dest_id self.src_id = src_id self.reserved = reserved self.length = length self.seq_num = seq_num self.flags = flags def serialize(self): """Return a bytes object of a packed transport header.""" return struct.pack('!5H', self.dest_id, self.src_id, self.length, self.seq_num, self.flags) def deserialize(self, data_buffer): """Not used for Transport headers""" pass def sizeof(self): """Return the size of the transport header.""" return struct.calcsize('!5H') def increment(self, step=1): """Increment the sequence number of the header by a given step.""" self.seq_num = (self.seq_num + step) % 0xFFFF def set_length(self, value): self.length = value def set_src_id(self, value): self.src_id = value def set_dest_id(self, value): self.dest_id = value def get_length(self): return self.length def get_src_id(self): return self.src_id def get_dest_id(self): return self.dest_id def response_required(self): """Sets bit 0 of the flags since a response is required.""" self.flags = self.flags | 0x1 def response_not_required(self): """Clears bit 0 of the flags since a response is not required.""" self.flags = self.flags & 0xFFFE def reset(self): """Reset the sequence number of the transport header.""" self.seq_num = 1 def is_reply(self, input_data): """Checks input_data to see if it is a valid reply to the last outgoing packet. Checks: input_data.dest_id == self.src_id input_data.src_id == self.dest_id input_data.seq_num == self.seq_num Raises a TypeError excpetion if input data is not the correct size. """ if len(input_data) == self.sizeof(): # FIXME: embedded TransportHeader format, here instead of deserialize dataTuple = struct.unpack('!5H', input_data[0:10]) if ((self.dest_id != dataTuple[1]) or (self.src_id != dataTuple[0]) or (self.seq_num != dataTuple[3])): msg = "WARNING: transport header mismatch:" msg += "[{0:d} {1:d}]".format(self.dest_id, dataTuple[1]) msg += "[{0:d} {1:d}]".format(self.src_id, dataTuple[0]) msg += "[{0:d} {1:d}]".format(self.seq_num, dataTuple[3]) print(msg) return False else: return True else: raise TypeError(str("TransportHeader: length of header " + "did not match size of transport header")) # End Class class CmdRespMessage(Message): """Base class for command / response messages. Attributes: command -- (uint32) ID of the command / response length -- (uint16) Length of the cmd / resp arguments (args) in bytes num_args -- (uint16) Number of uint32 arguments args -- (list of uint32) Arguments of the command / reponse """ command = None length = None num_args = None args = None payload = None raw_data = None def __init__(self, command=0, length=0, num_args=0, args=None, payload=None): super(CmdRespMessage, self).__init__() self.command = command self.length = length self.num_args = num_args self.args = args or [] self.payload = payload or [] self.raw_data = bytearray(self.serialize()) def serialize(self): """Return a bytes object of a packed command / response.""" ret_val = b'' # Add the command header ret_val += struct.pack('!I 2H', self.command, self.length, self.num_args) # Add the command arguments, if any if self.num_args > 0: ret_val += struct.pack('!{0}I'.format(self.num_args), *self.args) # Add the command payload, if any if self.payload: ret_val += self.payload self.raw_data = bytearray(ret_val) return ret_val def deserialize(self, data_buffer): """Populate the fields of a CmdRespMessage from a buffer.""" try: dataTuple = struct.unpack('!I 2H', data_buffer[0:8]) self.command = dataTuple[0] self.length = dataTuple[1] self.num_args = dataTuple[2] self.args = list(struct.unpack_from('!%dI' % self.num_args, data_buffer, offset=8)) self.raw_data = bytearray(data_buffer) # Store any raw payload following the response arguments # Payload (if any) starts at first byte after last resp argument payload_offset = 8 + 4*self.num_args if payload_offset < len(self.raw_data): self.resp_payload = self.raw_data[payload_offset:] else: self.resp_payload = bytearray() except struct.error as err: # Reset Cmd/Resp. We want predictable behavior on error self.reset() print("Error unpacking cmd/resp: {0}".format(err)) def get_bytes(self): """Returns the data buffer as bytes.""" return self.raw_data def sizeof(self): """Return the size of the cmd/resp including all attributes.""" if self.num_args == 0: return struct.calcsize('!I 2H') else: return struct.calcsize('!I 2H %dI' % self.num_args) def reset(self): """Reset the CmdRespMessage object to a default state (all zeros)""" self.command = 0 self.length = 0 self.num_args = 0 self.args = [] # End Class class Cmd(CmdRespMessage): """Base class for commands. Attributes: resp_type -- Response type of the command. See transport documentation for defined response types. By default, a Cmd will require a Resp as a response. See documentation of CmdRespMessage for additional attributes """ resp_type = None def __init__(self, command=0, length=0, num_args=0, args=None, resp_type=None): super(Cmd, self).__init__(command, length, num_args, args) self.resp_type = resp_type or transport.TRANSPORT_RESP def set_args(self, *args): """Set the command arguments.""" self.args = args self.num_args = len(args) self.length = self.num_args * 4 def add_args(self, *args): """Append arguments to current command argument list. Since the transport only operates on unsigned 32 bit integers, the command will convert all values to 32 bit unsigned integers. """ import ctypes for arg in args: if (arg < 0): arg = ctypes.c_uint32(arg).value self.args.append(arg) self.num_args += len(args) self.length += len(args) * 4 def set_payload(self, payload): """Adds an arbitrary payload to the message, typically a bytearray. The payload is appended after any arguments when the message is serialized for transmission over-the-wire. """ # This method sets the entire payload, possibly replacing an existing payload # First subtract the length of the existing payload, then set the new payload # and length values if self.payload is not None: self.length -= len(self.payload) self.payload = payload self.length += len(self.payload) def get_resp_type(self): return self.resp_type def process_resp(self, resp): """Process the response of the command.""" raise NotImplementedError def __str__(self): """Pretty print the Command""" msg = "" if self.command is not None: msg += "Command [{0:d}] ".format(self.command) msg += "({0:d} bytes): \n".format(self.length) if (self.num_args > 0): msg += " Args [0:{0:d}] : \n".format(self.num_args) for i in range(len(self.args)): msg += " 0x{0:08x}\n".format(self.args[i]) return msg # End Class class BufferCmd(CmdRespMessage): """Base Class for Buffer commands. Arguments: start_byte -- (uint32) Starting address of the buffer for this message. size -- (uint32) Size of the buffer in bytes - Reserved value: CMD_BUFFER_GET_SIZE_FROM_DATA Attributes: resp_type -- Response type of the command. See transport documentation for defined response types. By default, a BufferCmd will require a Buffer as a repsonse. Note: The wire format of a Buffer command is: Word[0] -- start_address of transfer Word[1] -- size of transfer (in bytes) Word[2 - N] -- Additional arguments To add additional arguments to a BufferCmd, use the add_args() method. See documentation of CmdRespMessage for additional attributes """ resp_type = None start_byte = None size = None def __init__(self, command=0, start_byte=0, size=0): super(BufferCmd, self).__init__(command=command, length=16, num_args=2, args=[start_byte, size]) self.resp_type = transport.TRANSPORT_BUFFER self.start_byte = start_byte self.timestamp_in_hdr = False if (size == CMD_BUFFER_GET_SIZE_FROM_DATA): self.size = 0 else: self.size = size def get_resp_type(self): return self.resp_type def get_buffer_start_byte(self): return self.start_byte def get_buffer_size(self): return self.size def update_start_byte(self, value): self.start_byte = value self.args[0] = value # FIXME: why self.args[0] not self.start_byte? def update_size(self, value): self.size = value self.args[1] = value # FIXME: why self.args[1] not self.size? def add_args(self, *args): """Append arguments to current command argument list. Since the transport only operates on unsigned 32 bit integers, the command will convert all values to 32 bit unsigned integers. """ import ctypes for arg in args: if (arg < 0): arg = ctypes.c_uint32(arg).value self.args.append(arg) self.num_args += len(args) self.length += len(args) * 4 def process_resp(self, resp): """Process the response of the command.""" raise NotImplementedError def __str__(self): """Pretty print the Command""" msg = "" if self.command is not None: msg += "Buffer Command [{0:d}] ".format(self.command) msg += "({0:d} bytes): \n".format(self.length) if (self.num_args > 0): msg += " Args [0:{0:d}] : \n".format(self.num_args) for i in range(len(self.args)): msg += " 0x{0:08x}\n".format(self.args[i]) return msg # End Class class Resp(CmdRespMessage): """Class for responses. This class is used if a command will return at most a single Ethernet packet of data from the node. See documentation of CmdRespMessage for attributes """ def get_args(self): """Return the response arguments.""" return self.args def resp_is_valid(self, num_args=None, status_errors=None, name=None): """Checks if the response is valid Args: num_args (int, optional): Check that the number of arguments matches the provided value status_errors (dict, optional): Dictionary of status errors for the form: { : } name (str, optional): Name of the command Returns: is_valid (bool): Is the response valid? """ error = False msg = "ERROR:\n" if num_args is not None: if len(self.args) != num_args: msg += " Number of arguments in response ({0}) does not match\n".format(len(self.args)) msg += " number of expected arguments ({0})\n".format(num_args) error = True if status_errors is not None: if self.args: if (self.args[0] in status_errors.keys()): if name is not None: msg += " Received status error {0}\n".format(name) else: msg += " Received status error\n" msg += " {0}\n".format(status_errors[self.args[0]]) error = True else: if name is not None: msg += " No status in response {0}\n".format(name) else: msg += " No status in response\n" error = True if error: print(msg) print(self) return False else: return True def __str__(self): """Pretty print the Response""" msg = "" if self.command is not None: msg += "Response [CMDID = 0x{0:08x}] ".format(self.command) msg += "({0:d} bytes): \n".format(self.length) if (self.num_args > 0): msg += " Args [0:{0:d}] : \n".format(self.num_args) for i in range(len(self.args)): msg += " 0x{0:08x}\n".format(self.args[i]) return msg # End Class class Buffer(Message): """Class for buffer for transferring generic information. This class is used if a command will return one or more Ethernet packets of data from the node. This object provides a container to transfer information that will be decoded by higher level functions. Attributes: complete -- Flag to indicate if buffer contains all of the bytes indicated by the size parameter start_byte -- Start byte of the buffer num_bytes -- Number of bytes currently contained within the buffer Wire Data Format: command -- (uint32) command / response ID length -- (uint16) Length of the cmd / resp args in bytes num_args -- (uint16) Number of uint32 arguments bytes_remainig -- (uint32) Number of bytes remain in the current request start_byte -- (uint32) Address of start byte for the transfer size -- (uint32) Size of the buffer in bytes buffer -- (list of uint8) Content of the buffer """ complete = None start_byte = None num_bytes = None tracker = None size = None buffer = None def __init__(self, start_byte=0, size=0, data_buffer=None): self.start_byte = start_byte self.size = size self.timestamp_in_hdr = False self.retrieval_timestamp = None self.tracker = [{0:start_byte, 1:start_byte, 2:0}] if data_buffer is not None: self._add_buffer_data(start_byte, data_buffer) else: # Create an empty buffer of the specified size self.complete = False self.num_bytes = 0 self.buffer = bytearray(self.size) def serialize(self): """Return a bytes object of a packed buffer.""" return self.serialize_cmd() def serialize_cmd(self, command=None, start_byte=None): """Return a bytes object of a packed buffer.""" if command is None: command = 0 if start_byte is None: start_byte = self.start_byte # FIXME: who calls this method? And why does this header # use the response header format (with 5I) instead of # the command header format (with 4I)? # FIXME 2: Updated format to remove buffer_id and flags, still # don't know who calls this method to construct a response header? return struct.pack('!I 2H 3I %dB' % self.size, command, 20, 3, # length = Num_args * 4 bytes / arg; Num_args = 3; 0, start_byte, self.size, *self.buffer) def deserialize(self, data_buffer): """Populate the fields of a Buffer with a message raw_data.""" (args, data) = self._unpack_data(data_buffer) bytes_remaining = args[0] start_byte = args[1] offset = (start_byte - self.start_byte) self._update_buffer_size(bytes_remaining) self._add_buffer_data(offset, data) self._set_buffer_complete() def add_data_to_buffer(self, raw_data): """Add the raw data (with the format of a Buffer) to the current Buffer. Note: This will check to make sure that data is for the given buffer as well as place it in the appropriate place indicated by the start_byte. """ (args, data) = self._unpack_data(raw_data) bytes_remaining = args[0] start_byte = args[1] # Old code checked args.buffer_id == self.buffer_id # This check never failed in old code # This is where new code could verify req_id? offset = (start_byte - self.start_byte) self._update_buffer_size(bytes_remaining) self._add_buffer_data(offset, data) self._set_buffer_complete() def append(self, buffer): """Append the contents of the provided Buffer to the current Buffer.""" curr_size = self.size new_size = curr_size + buffer.get_buffer_size() self._update_buffer_size(new_size, force=1) self._add_buffer_data(curr_size, buffer.get_bytes()) self._set_buffer_complete() def merge(self, buffer): """Merge the contents of the provided Buffer to the current Buffer.""" start_byte = buffer.get_start_byte() offset = (start_byte - self.start_byte) size = buffer.get_buffer_size() end_byte = offset + size if (end_byte > self.size): # Need to update the buffer to allocate more memory first self._update_buffer_size(end_byte, force=1) self._add_buffer_data(offset, buffer.get_bytes()) self._set_buffer_complete() def trim(self): """Trim the Buffer to the largest contiguous number of bytes received.""" locations = self.get_missing_byte_locations() if locations: # This assumes that the missing byte locations are in order # with the first missing byte after the start byte in the first # item in the list. missing_start = locations[0][0] contiguous_size = missing_start - self.start_byte self.num_bytes = contiguous_size self._update_buffer_size(contiguous_size, force=1) self._set_buffer_complete() def sizeof(self): """Return the size of the Buffer including all attributes.""" # Do not calculate the size of the buffer just using calcsize. # This is extremely memory inefficient for large buffers and can # cause memory issues. # 3I is format of list/log retrieval header (3 u32 args) return (struct.calcsize('!3I') + self.size) def get_start_byte(self): return self.start_byte def get_header_size(self): return struct.calcsize('!3I') def get_buffer_size(self): return self.size def get_occupancy(self): return self.num_bytes def set_bytes(self, data_buffer): """Set the message bytes of the Buffer.""" self._update_buffer_size(len(data_buffer), force=1) self._add_buffer_data(0, data_buffer) self._set_buffer_complete() def get_bytes(self): """Return the message bytes of the Buffer.""" return self.buffer def get_missing_byte_locations(self): """Returns a list of tuples (start_index, end_index, size) that contain the missing byte locations. """ if not self.complete: return self._find_missing_bytes() else: return [] def is_buffer_complete(self): """Return if the Buffer is complete.""" return self.complete def reset(self): """Reset the Buffer object to a default state (all zeros)""" self.buffer_id = 0 self.flags = 0 self.size = 0 self.buffer = bytearray(self.size) def __str__(self): """Pretty print the Buffer""" msg = "" if self.buffer is not None: msg += "Buffer [{0:d}] ".format(self.buffer_id) msg += "({0:d} bytes): \n".format(self.size) msg += " Flags : 0x{0:08x} \n".format(self.flags) msg += " Start : {0:d}\n".format(self.start_byte) msg += " Num bytes: {0:d}\n".format(self.num_bytes) msg += " Complete : {0}\n".format(self.complete) if (False): msg += " Data : " for i in range(len(self.buffer)): if (i % 16) == 0: msg += "\n {0:02x} ".format(self.buffer[i]) else: msg += "{0:02x} ".format(self.buffer[i]) return msg # ------------------------------------------------------------------------- # Internal helper methods # ------------------------------------------------------------------------- def _unpack_data(self, raw_data): """Internal method to unpack a data buffer.""" args = [] data = [] if 1: # First unpack just the cmd_resp header cmd_resp_hdr = CmdRespMessage() cmd_resp_hdr.deserialize(raw_data) args = cmd_resp_hdr.args # List and log retrieval responses set num_args=0 when returning # an empty list/payload if args: if self.timestamp_in_hdr: # List retrieval responses include timestamp list_ret_args = struct.unpack('!3I Q', raw_data[8:28]) size = args[2] self.retrieval_timestamp = args[3] data = struct.unpack_from('!%dB' % size, raw_data, offset=28) args += list_ret_args else: # Log retrieval response does not include timestamp log_ret_args = struct.unpack('!3I', raw_data[8:20]) size = args[2] data = struct.unpack_from('!%dB' % size, raw_data, offset=20) args += log_ret_args else: # Node returned empty list/log buffer # Construct "empty" response args to return # 0: bytes_remaining # 1: start_byte # 2: size args = (0, 0, 0) data = [] #except struct.error as err: # # Ignore the data. We want predictable behavior on error # print("Error unpacking buffer:") # print(" {0}".format(err)) # print(" Ignorning data. This error could be caused by a mismatch between\n") # print(" the maximum packet size between the node and the host.\n") return (args, data) def _update_buffer_size(self, size, force=0): """Internal method to update the size of the transfer.""" if (self.size == 0): self.size = size self.buffer = bytearray(self.size) elif (force == 1): # Update the size of the buffer old_size = self.size self.size = size # Update the buffer allocation if (size > old_size): self.buffer.extend(bytearray(size - old_size)) else: self.buffer = self.buffer[:size] def _add_buffer_data(self, buffer_offset, data_buffer): """Internal method to add data to the Buffer Only self.size bytes were allocated for the Buffer. Therefore, only take an offset from the start_byte (ie a relative address) for where to store the data in the Buffer. If the provided data is greater than specified Buffer size, then the data will be truncated. """ data_to_add_size = len(data_buffer) buffer_end_byte = buffer_offset + data_to_add_size # If the buffer size will be exceeded, then truncate the add if (buffer_end_byte > self.size): buffer_end_byte = self.size data_to_add_size = buffer_end_byte - buffer_offset # Update the tracker with the information # - Need to convert back to absolute addresses for tracker self._update_tracker((buffer_offset + self.start_byte), (buffer_end_byte + self.start_byte), data_to_add_size) # Add the data to the buffer if (data_to_add_size > 0): self.buffer[buffer_offset:buffer_end_byte] = data_buffer[:data_to_add_size] # Update the ocupancy of the buffer self.num_bytes = self._tracker_size() # Set the buffer complete flag self._set_buffer_complete() def _set_buffer_complete(self): """Internal method to set the complete flag on the Buffer.""" if (self.num_bytes == self.size): self.complete = True elif (self.num_bytes < self.size): self.complete = False else: print("WARNING: Buffer out of sync. Should never reach here.") def _tracker_size(self): """Internal method to get the Buffer size via the tracker.""" size = 0 for item in self.tracker: size += item[2] return size def _update_tracker(self, start_byte, end_byte, size): """Internal method to update the tracker.""" done = False # Don't add duplicate entries for item in self.tracker: if (start_byte == item[0]) and (end_byte == item[1]) and (size == item[2]): return # Can this update be added to one of the current tracker entries for item in self.tracker: if (start_byte == item[1]): item[1] += size item[2] += size done = True if not done: self.tracker.append({0:start_byte, 1:end_byte, 2:size}) # Try to compress the tracker self._compress_tracker() def _compress_tracker(self): """Internal method to compress the tracker.""" # If there is more than one item, try to compress them if (len(self.tracker) > 1): tracker = [] tmp_tracker = sorted(self.tracker, key=lambda k: k[0]) tmp_item = tmp_tracker[0] tracker.append(tmp_item) # For each remaining item if the start_byte equals the end_byte # of the start_item, then merge the items for item in tmp_tracker[1:]: if (item[0] == tmp_item[1]): tmp_item[1] = item[1] tmp_item[2] += item[2] else: tmp_item = item tracker.append(tmp_item) self.tracker = tracker def _find_missing_bytes(self): """Internal method to find the missing bytes using the tracker.""" ret_val = [] missing_bytes = self.size - self.num_bytes start = self.start_byte end = self.start_byte + self.size tmp_tracker = sorted(self.tracker, key=lambda k: k[0]) tracker_count = list(tmp_tracker) error = False if (missing_bytes != 0): # Iterate through all the items in the list and remove # them to build up the holes for item in tmp_tracker: # Process item but don't add a hole if (start == item[0]): start = item[1] tracker_count.remove(item) # There is a missing piece of the buffer to request elif ((start + missing_bytes) >= item[0]): tmp_size = item[0] - start if tmp_size < 0: print("WARNING: Issue with finding missing bytes.") print(" Size between items is negative.") print(" item : ({0}, {1}, {2})".format(item[0], item[1], item[2])) print(" start : {0}".format(start)) error = True else: missing_bytes -= tmp_size ret_val.append((start, item[0], tmp_size)) start = item[1] tracker_count.remove(item) # There was an error in the tracker else: print("WARNING: Issue with tracking missing bytes.") print(" Number of missing bytes does not cover hole between tracker items.") print(" Missing Bytes : {0}".format(missing_bytes)) tmp_size = item[0] - start missing_bytes = 0 ret_val.append((start, item[0], tmp_size)) start = item[1] tracker_count.remove(item) error = True if tracker_count: print("WARNING: Issue with finding missing bytes.") print(" Not all tracker items processed.") error = True if error: print(" Tracker : {0}".format(self.tracker)) print(" Tmp Tracker : {0}".format(tmp_tracker)) print(" Remaining Items : {0}".format(tracker_count)) print(" Locations : {0}".format(ret_val)) # Find any holes at the end of the buffer if (missing_bytes != 0): if (end != start): tmp_size = end - start ret_val.append((start, end, tmp_size)) missing_bytes -= tmp_size # Missing bytes, so print a warning if (missing_bytes != 0): print("WARNING: Could not find all missing bytes: {0}".format(missing_bytes)) return ret_val # End Class