# Source code for wlan_exp.transport.node

# -*- coding: utf-8 -*-
"""
------------------------------------------------------------------------------
Mango 802.11 Reference Design Experiments Framework - Transport Node 
------------------------------------------------------------------------------
License:   Copyright 2019 Mango Communications, Inc. All rights reserved.
           Use and distribution subject to terms in LICENSE.txt
------------------------------------------------------------------------------

This module provides class definition for Transport Node.

Classes (see below for more information):
    WlanExpTransportNode()        -- Base class for Transport nodes
    WlanExpTransportNodeFactory() -- Factory class for creating a WlanExpTransportNode

Integer constants:
    NODE_TYPE, NODE_ID, NODE_HW_GEN, NODE_SERIAL_NUM, 
      NODE_FPGA_DNA -- Node hardware parameter constants 

If additional hardware parameters are needed for sub-classes of WlanExpTransportNode, 
please make sure that the values of these hardware parameters are not reused.

"""

from . import cmds
from . import exception as ex


# Public API of this module
__all__ = [
    'WlanExpTransportNode',
    'WlanExpTransportNodeFactory',
]


# Node hardware parameter identifiers (consecutive integers starting at 0)
#     - The C counterparts are found in *_node.h and must stay in sync
#     - Sub-classes adding parameters must not reuse these values
(NODE_TYPE,
 NODE_ID,
 NODE_HW_GEN,
 NODE_SERIAL_NUM,
 NODE_FPGA_DNA) = range(5)



class WlanExpTransportNode(object):
    """Base Class for Transport node.

    The Transport node represents one node in a network.  This class is the
    primary interface for interacting with nodes by providing methods for
    sending commands and checking status of nodes.

    By default, the base Transport node provides many useful node attributes
    as well as a transport component.

    Attributes:
        network_config       -- Network configuration used to set up the transports
        node_type            -- Unique type of the Transport node
        node_id              -- Unique identification for this node
        name                 -- User specified name for this node (supplied by user scripts)
        description          -- String description of this node (auto-generated)
        serial_number        -- Node's serial number, read from EEPROM on hardware
        sn_str               -- String form of the serial number (from util.get_serial_number)
        fpga_dna             -- Node's FPGA's unique identification (on select hardware)

        transport            -- Node's transport object
        transport_broadcast  -- Node's broadcast transport object
        transport_tracker    -- Number of consecutive receive failures
    """
    network_config           = None

    node_type                = None
    node_id                  = None
    name                     = None
    description              = None
    serial_number            = None
    sn_str                   = None
    fpga_dna                 = None

    transport                = None
    transport_broadcast      = None
    transport_tracker        = None

    def __init__(self, network_config=None):
        """Create a node bound to a network configuration.

        Args:
            network_config -- Network configuration object; when None a
                              default config.NetworkConfiguration() is created
        """
        if network_config is not None:
            self.network_config = network_config
        else:
            from . import config

            self.network_config = config.NetworkConfiguration()

        self.transport_tracker = 0


    def __del__(self):
        """Clear the transport object to close any open socket connections
        in the event the node is deleted"""
        if self.transport:
            self.transport.transport_close()
            self.transport = None

        if self.transport_broadcast:
            self.transport_broadcast.transport_close()
            self.transport_broadcast = None


    def set_init_configuration(self, serial_number, node_id, node_name,
                               ip_address, unicast_port, broadcast_port):
        """Set the initial configuration of the node.

        Creates the unicast and broadcast transports if they are not already
        present, opens them with the buffer sizes from ``network_config``, and
        records the node's identity.

        Args:
            serial_number  -- Node serial number (parsed by util.get_serial_number)
            node_id        -- Node ID; used as the unicast transport destination ID
            node_name      -- User supplied name for the node
            ip_address     -- Node IP address
            unicast_port   -- Port for unicast traffic
            broadcast_port -- Port for broadcast traffic

        Raises:
            ValueError -- If the configured transport type is not supported
                          and no transport objects are attached to the node.
        """
        from . import util
        import wlan_exp.platform as platform

        host_id      = self.network_config.get_param('host_id')
        tx_buf_size  = self.network_config.get_param('tx_buffer_size')
        rx_buf_size  = self.network_config.get_param('rx_buffer_size')
        tport_type   = self.network_config.get_param('transport_type')

        (sn, sn_str) = util.get_serial_number(serial_number)
        p = platform.lookup_platform_by_serial_num(serial_number)
        if p:
            self.platform_id = p.platform_id
        else:
            print('WARNING: no platform found for serial number {}'.format(serial_number))
            self.platform_id = -1

        if (tport_type == 'python'):
            from . import transport_eth_ip_udp_py as unicast_tp
            from . import transport_eth_ip_udp_py_broadcast as broadcast_tp

            if self.transport is None:
                self.transport = unicast_tp.TransportEthIpUdpPy()

            if self.transport_broadcast is None:
                self.transport_broadcast = broadcast_tp.TransportEthIpUdpPyBroadcast(self.network_config)
        else:
            print("Transport not defined\n")

        # Set Node information
        self.node_id       = node_id
        self.name          = node_name
        self.serial_number = sn
        self.sn_str        = sn_str

        # Without transport objects the calls below would fail with an
        # opaque AttributeError on None; report the actual cause instead.
        if self.transport is None or self.transport_broadcast is None:
            raise ValueError("Unsupported transport type: {0}".format(tport_type))

        # Set Node Unicast Transport information
        self.transport.transport_open(tx_buf_size, rx_buf_size)
        self.transport.set_ip_address(ip_address)
        self.transport.set_unicast_port(unicast_port)
        self.transport.set_broadcast_port(broadcast_port)
        self.transport.set_src_id(host_id)
        self.transport.set_dest_id(node_id)

        # Set Node Broadcast Transport information (0xFFFF = all nodes)
        self.transport_broadcast.transport_open(tx_buf_size, rx_buf_size)
        self.transport_broadcast.set_ip_address(ip_address)
        self.transport_broadcast.set_unicast_port(unicast_port)
        self.transport_broadcast.set_broadcast_port(broadcast_port)
        self.transport_broadcast.set_src_id(host_id)
        self.transport_broadcast.set_dest_id(0xFFFF)


    def configure_node(self, jumbo_frame_support=False):
        """Get remaining information from the node and set remaining parameters.

        Args:
            jumbo_frame_support -- Currently unused by this method
        """
        self.transport.ping(self)

        # Retrieve the hardware node's NODE_INFO and PLATFORM_NODE_INFO structs
        #  This method intentionally retrieves and applies the node_info structs
        #  in two steps. In other implementations the node_info structs might
        #  be known a priori, and could be applied to the node object without
        #  any over-the-wire handshake
        hw_node_info = self.get_node_info()

        # hw_node_info is tuple of InfoStruct instances for the Node Info structs
        #  returned by the node. The first struct is always a platform-agnostic
        #  NODE_INFO as defined in info.py. Other structs in tuple are produced
        #  and consumed by platform code
        self.update_node_info(hw_node_info)

        # Set description
        self.description = self.__repr__()

    def update_node_info(self, node_info_struct):
        """Apply node info structs to this node; must be implemented by sub-classes.

        Raises:
            NotImplementedError -- Always, in this base class.
        """
        raise NotImplementedError('ERROR: superclass does not handle update_node_info!')

    def get_type_ids(self):
        """Get the type of the node. The node type identifies the node's
        platform ID and the software application IDs in CPU High and Low.
        This method returns the minimum info required for the init flow
        to select the correct class of wlan_exp node object"""
        return self.send_cmd(cmds.NodeGetType())


    def get_node_info(self):
        """Get the Hardware Information from the node."""
        return self.send_cmd(cmds.GetNodeInfo())

    def test_mtu(self, mtu):
        """Tests that the Node->Host link supports a given MTU. The current version
        of this method does not test the Host->Node link due to a limitation in
        the Eth Rx handling via the wlan_mac_queue subsystem. Queue buffers are
        limited to 4kB, sufficient for all current uses including all wlan_exp
        host->node messages, but insufficient to test a full 9kB jumbo MTU.

        Returns:
            The processed command response, or False if the transport timed out.
        """
        try:
            return self.send_cmd(cmds.TestTransportMTU(mtu))

        except ex.TransportError:
            # TransportError catches timeout, either when node does not respond
            #  or node does respond but host drops the response for being too
            #  big. Usually a TransportError should halt the script. In this
            #  special case we catch the timeout to return False so the init
            #  code can continue its node init using the last known-good MTU
            return False

    def set_max_resp_words(self, max_words):
        """Sets the maximum number of payload words the node may include
        in any response packet. The value must be derived from the MTU of the
        host, node, and network"""
        return self.send_cmd(cmds.TransportSetMaxRespWords(max_words))

    def set_name(self, name):
        """Set the name of the node.

        The name provided will affect the Python environment only
        (ie it will update strings in child classes but will not be
        transmitted to the node.)

        Args:
            name (str):  User provided name of the node.
        """
        self.name        = name
        self.description = self.__repr__()



    # -------------------------------------------------------------------------
    # Commands for the Node
    # -------------------------------------------------------------------------
    def identify(self):
        """Identify the node

        The node will physically identify itself by:

          * Blinking the Hex Display (for approx 10 seconds)
          * Output Node ID and IP address to UART output
        """
        self.send_cmd(cmds.NodeIdentify(self.sn_str))

    def ping(self):
        """'Ping' the node

        Send an empty packet to the node via the transport to test connectivity
        between the host and the node.  This is the simplest command that can
        be processed by the node and is similar to the "ping" command used to
        check network connectivity.
        """
        self.transport.ping(self, output=True)

    def get_temp(self):
        """Get the current temperature of the node."""
        (curr_temp, _, _) = self.send_cmd(cmds.NodeGetTemperature()) # Min / Max temp not used
        return curr_temp

    def setup_network_inf(self):
        """Setup the transport network information for the node (broadcast)."""
        self.send_cmd_broadcast(cmds.NodeSetupNetwork(self))

    def reset_network_inf(self):
        """Reset the transport network information for the node (broadcast)."""
        self.send_cmd_broadcast(cmds.NodeResetNetwork(self.sn_str))

    # -------------------------------------------------------------------------
    # Transmit / Receive methods for the Node
    # -------------------------------------------------------------------------
    def send_cmd(self, cmd, max_attempts=2, max_req_size=None, timeout=None):
        """Send the provided command.

        Args:
            cmd          -- Class of command to send
            max_attempts -- Maximum number of attempts to send a given command
            max_req_size -- Maximum request size (applies only to Buffer Commands)
            timeout      -- Maximum time to wait for a response from the node

        Returns:
            The result of cmd.process_resp() for commands with a response,
            otherwise None.

        Raises:
            TransportError -- On an unknown response type, or when the
                              response could not be received.
        """
        from . import transport

        resp_type = cmd.get_resp_type()

        if  (resp_type == transport.TRANSPORT_NO_RESP):
            payload = cmd.serialize()
            self.transport.send(payload, robust=False)

        elif (resp_type == transport.TRANSPORT_RESP):
            resp = self._receive_resp(cmd, max_attempts, timeout)
            return cmd.process_resp(resp)

        elif (resp_type == transport.TRANSPORT_BUFFER):
            resp = self._receive_buffer(cmd, max_attempts, max_req_size, timeout)
            return cmd.process_resp(resp)

        else:
            raise ex.TransportError(self.transport,
                                    "Unknown response type for command")


    def _receive_resp(self, cmd, max_attempts, timeout):
        """Internal method to receive a response for a given command payload.

        The payload is re-sent after each receive failure until the
        consecutive failure count reaches max_attempts.

        Raises:
            TransportError -- After max_attempts consecutive failures.
        """
        from . import message

        resp = message.Resp()

        payload = cmd.serialize()
        self.transport.send(payload)

        while True:
            try:
                reply = self.transport.receive(timeout)
            except ex.TransportError:
                self._receive_failure()

                if self._receive_failure_exceeded(max_attempts):
                    raise ex.TransportError(self.transport,
                              "Max retransmissions without reply from node")

                self.transport.send(payload)
            else:
                self._receive_success()
                resp.deserialize(reply)
                return resp


    def _receive_buffer(self, cmd, max_attempts, max_req_size, timeout):
        """Internal method to receive a buffer for a given command payload.

        Depending on the size of the buffer, the framework will split a
        single large request into multiple smaller requests based on the
        max_req_size (further capped at the transport's rx_buffer_size).
        This is to:
          1) Minimize the probability that the OS drops a packet
          2) Minimize the time that the Ethernet interface on the node is busy
             and cannot service other requests

        Args:
            cmd          -- Buffer command; its start byte and size are
                            updated in place while fragments / missing data
                            are requested
            max_attempts -- Maximum consecutive receive failures tolerated;
                            also used for fragment and re-request commands
            max_req_size -- Maximum bytes per request (None: no extra limit)
            timeout      -- Receive timeout; also used for fragment and
                            re-request commands

        Returns:
            The trimmed message.Buffer. It may be truncated if a fragment is
            incomplete or a re-request for missing data times out.

        Raises:
            TransportError -- After max_attempts consecutive receive failures.

        To see performance data, set the 'display_perf' flag to True.
        """
        from . import message

        display_perf    = False
        print_warnings  = True
        print_debug_msg = False

        start_byte      = cmd.get_buffer_start_byte()

        # NOTE(review): total_size may be a sentinel rather than a real size.
        #  Reportedly LogGetEvents() replaces CMD_PARAM_LOG_GET_ALL_ENTRIES
        #  with CMD_BUFFER_GET_SIZE_FROM_DATA, and both are 0xFFFFFFFF. This
        #  method treats the value as a literal byte count. A sentinel-sized
        #  request is therefore split into fragments and typically ends in the
        #  "did not return a complete fragment" branch below, which returns a
        #  truncated buffer. TODO confirm this is the intended behavior.
        total_size      = cmd.get_buffer_size()

        if max_req_size is not None:
            fragment_size = max_req_size
        else:
            fragment_size = total_size

        # To not hurt the performance of the transport, do not request more
        # data than can fit in the RX buffer
        if (fragment_size > self.transport.rx_buffer_size):
            fragment_size = self.transport.rx_buffer_size

        # Allocate a complete response buffer
        resp = message.Buffer(start_byte, total_size)
        resp.timestamp_in_hdr = cmd.timestamp_in_hdr

        if display_perf:
            import time
            print("Receive buffer")
            start_time = time.time()

        # If the transfer is more than the fragment size, then split the transaction
        if (total_size > fragment_size):
            size      = fragment_size
            start_idx = start_byte
            num_bytes = 0

            while (num_bytes < total_size):
                if (print_debug_msg):
                    print("\nFRAGMENT:  {0:10d}/{1:10d}\n".format(num_bytes, total_size))

                # Handle the case of the last fragment
                if ((num_bytes + size) > total_size):
                    size = total_size - num_bytes

                # Update the command with the location and size of fragment
                cmd.update_start_byte(start_idx)
                cmd.update_size(size)

                # Re-enter send_cmd for this fragment. size never exceeds
                #  rx_buffer_size, so (assuming get_buffer_size() reflects
                #  update_size()) the nested call takes the single-request
                #  path below rather than fragmenting again. The caller's
                #  retry and timeout settings are forwarded.
                tmp_resp = self.send_cmd(cmd, max_attempts=max_attempts,
                                         timeout=timeout)
                tmp_size = tmp_resp.get_buffer_size()

                if (tmp_size == size):
                    # Add the response to the buffer and increment loop variables
                    resp.merge(tmp_resp)
                    num_bytes += size
                    start_idx += size
                else:
                    # Exit the loop because communication has failed for the
                    # fragment and there is no point in requesting the next
                    # fragment.  Only return the truncated buffer.
                    if (print_warnings):
                        msg  = "WARNING:  Command did not return a complete fragment.\n"
                        msg += "  Requested : {0:10d}\n".format(size)
                        msg += "  Received  : {0:10d}\n".format(tmp_size)
                        msg += "Returning truncated buffer."
                        print(msg)

                    break
        else:
            # Normal buffer receive flow
            payload = cmd.serialize()
            self.transport.send(payload)

            while not resp.is_buffer_complete():
                try:
                    reply = self.transport.receive(timeout)
                    self._receive_success()
                except ex.TransportError:
                    self._receive_failure()
                    if print_warnings:
                        print("WARNING:  Transport timeout.  Requesting missing data.")

                    # If there is a timeout, then request missing part of the buffer
                    if self._receive_failure_exceeded(max_attempts):
                        if print_warnings:
                            print("ERROR:  Max re-transmissions without reply from node.")
                        raise ex.TransportError(self.transport,
                                  "Max retransmissions without reply from node")

                    # Get the missing locations (location[0] is used as the
                    #  start byte and location[2] as the size below)
                    locations = resp.get_missing_byte_locations()

                    if print_debug_msg:
                        print(resp)
                        print(resp.tracker)
                        print("Missing Locations in Buffer:")
                        print(locations)

                    # Send commands to fill in the buffer
                    for location in locations:
                        if (print_debug_msg):
                            print("\nLOCATION: {0:10d}    {1:10d}\n".format(location[0], location[2]))

                        # Validate before mutating the command
                        if (location[2] < 0):
                            print("ERROR:  Issue with finding missing bytes in response:")
                            print("Response Tracker:")
                            print(resp.tracker)
                            print("\nMissing Locations:")
                            print(locations)
                            raise Exception("Invalid missing-byte location: {0}".format(location))

                        # Update the command with the new location
                        cmd.update_start_byte(location[0])
                        cmd.update_size(location[2])

                        # Use the standard send to get a Buffer with missing data.
                        # This avoids any race conditions when requesting
                        # multiple missing locations. The failure counter is not
                        # reset first, so the nested call shares it with this
                        # loop. That bounds how many further failures the
                        # re-request tolerates and prevents endless retries.
                        try:
                            location_resp = self.send_cmd(cmd, max_attempts=max_attempts,
                                                          timeout=timeout)
                            self._receive_success()
                        except ex.TransportError:
                            # Timed out on a re-request.  There is an error so
                            # just clean up the response and get out of the loop.
                            if print_warnings:
                                print("WARNING:  Transport timeout.  Returning truncated buffer.")
                                print("  Timeout requesting missing location: {1} bytes @ {0}".format(location[0], location[2]))

                            self._receive_failure()
                            resp.trim()
                            return resp

                        if print_debug_msg:
                            print("Adding Response:")
                            print(location_resp)
                            print(resp)

                        # Add the response to the buffer
                        resp.merge(location_resp)

                        if print_debug_msg:
                            print("Buffer after merge:")
                            print(resp)
                            print(resp.tracker)

                else:
                    resp.add_data_to_buffer(reply)

        # Trim the final buffer in case there were missing fragments
        resp.trim()

        if display_perf:
            print("    Receive time: {0}".format(time.time() - start_time))

        return resp


    def send_cmd_broadcast(self, cmd):
        """Send the provided command over the broadcast transport.

        Currently, broadcast commands cannot have a response.

        Args:
            cmd -- Class of command to send
        """
        self.transport_broadcast.send(payload=cmd.serialize())


    def receive_resp(self, timeout=None):
        """Return a list of responses that are sitting in the host's
        receive queue.  It will empty the queue and return them all to the
        calling method.

        Args:
            timeout -- Receive timeout passed to the transport

        Returns:
            List of message.Resp objects (empty if nothing was received).
        """
        from . import message

        output = []

        resp = self.transport.receive(timeout)

        if resp:
            # Create a list of response object if the list of bytes is a
            # concatenation of many responses
            done = False

            while not done:
                msg = message.Resp()
                msg.deserialize(resp)
                resp_len = msg.sizeof()

                # A non-positive length would never consume any bytes and
                # would loop forever; treat it as the last message.
                if 0 < resp_len < len(resp):
                    resp = resp[resp_len:]
                else:
                    done = True

                output.append(msg)

        return output



    # -------------------------------------------------------------------------
    # Transport Tracker
    # -------------------------------------------------------------------------
    def _receive_success(self):
        """Internal method - Successfully received a packet; reset the failure count."""
        self.transport_tracker = 0


    def _receive_failure(self):
        """Internal method - Had a receive failure; bump the failure count."""
        self.transport_tracker += 1


    def _receive_failure_exceeded(self, max_attempts):
        """Internal method - True once receive failures reach max_attempts."""
        return self.transport_tracker >= max_attempts


# End Class



class WlanExpTransportNodeFactory(WlanExpTransportNode):
    """Sub-class of Transport Node used to help with node configuration and setup.

    The factory talks to a node over its own transports to discover the
    node's type IDs, then creates an instance of the matching Python node
    class.

    Type / class mappings are stored in ``class_type_map``. It is a list of
    dictionaries with 'high_sw_id' and 'node_class' keys, added via
    add_node_type_class(). The first matching entry wins.

    To add new Node Types, sub-class WlanExpTransportNodeFactory or call
    add_node_type_class() with your own mappings.

    Attributes:
        type_dict      -- Not used by this class (always None here)
        class_type_map -- List of node type / node class mapping dictionaries
    """
    type_dict           = None


    def __init__(self, network_config=None):
        """Create the factory.

        Args:
            network_config -- Network configuration passed to the base class
                              (None selects a default configuration)
        """
        super(WlanExpTransportNodeFactory, self).__init__(network_config)

        # Initialize the list of node class/type mappings
        #  New mappings will be added by the context which creates the
        #  instance of this factory class
        self.class_type_map = []

    def setup(self, node_dict):
        """Configure the factory's transports for the node described by node_dict.

        Args:
            node_dict -- Dictionary with 'serial_number', 'node_id',
                         'node_name', 'ip_address', 'unicast_port' and
                         'broadcast_port' keys, forwarded to
                         set_init_configuration()
        """
        self.set_init_configuration(serial_number=node_dict['serial_number'],
                                    node_id=node_dict['node_id'],
                                    node_name=node_dict['node_name'],
                                    ip_address=node_dict['ip_address'],
                                    unicast_port=node_dict['unicast_port'],
                                    broadcast_port=node_dict['broadcast_port'])

    def create_node(self, network_config=None, network_reset=True):
        """Based on the Node Type, dynamically create and return the correct node.

        Args:
            network_config -- Network configuration supplying the 'mtu'
                              parameter copied to the new node's transport.
                              Defaults to the factory's own network_config.
            network_reset  -- If True, broadcast reset / setup network
                              commands before querying the node type

        Returns:
            The newly created node, or None if a TransportError occurred
            while talking to the node (e.g. the node is not responding).

        Raises:
            Exception -- If no node class matches the node's type IDs.
        """
        node = None

        # The MTU lookup below needs a configuration; without this fallback
        # the default argument (None) would raise AttributeError.
        if network_config is None:
            network_config = self.network_config

        # Initialize the node network interface
        if network_reset:
            # Send broadcast command to reset the node network interface
            self.reset_network_inf()

            # Send broadcast command to initialize the node network interface
            self.setup_network_inf()

        try:
            # Send unicast command to get the node type
            node_type_ids = self.get_type_ids()

            # Lookup the appropriate Python class for this node type
            #  The return value is the actual class (not an instance) that can
            #  be used to create a new object
            node_class = self.get_class_for_node_type_ids(node_type_ids)

            if node_class is None:
                raise Exception('ERROR: no matching node class for node type IDs {}'.format(node_type_ids))

            # NOTE(review): the node is constructed without network_config, so
            #  its transports are opened using its own default configuration.
            #  Confirm this is intended.
            node = node_class()

            node.set_init_configuration(serial_number=self.sn_str,
                                        node_id=self.node_id,
                                        node_name=self.name,
                                        ip_address=self.transport.ip_address,
                                        unicast_port=self.transport.unicast_port,
                                        broadcast_port=self.transport.broadcast_port)

            # Store the platform/application IDs as node parameters
            #  These are verified against the IDs returned in the NODE_INFO during init
            node.platform_id = node_type_ids[0]
            node.high_sw_id = node_type_ids[1]
            node.low_sw_id = node_type_ids[2]

            # Copy the network_config MTU to the node's transport object
            #  The node itself will report an MTU during init, the lesser
            #  of the network and node MTUs will be used to set the final
            #  transport.mtu used to configure the node's max response size
            node.transport.mtu = network_config.get_param('mtu')

            msg  = "Initializing {0}".format(node.sn_str)
            if node.name is not None:
                msg += " as {0}".format(node.name)
            print(msg)

        except ex.TransportError as err:
            msg  = "ERROR:  Node {0}\n".format(self.sn_str)
            msg += "    Node is not responding.  Please ensure that the \n"
            msg += "    node is powered on and is properly configured.\n"
            print(msg)
            print(err)

        return node

    def add_node_type_class(self, class_type_mapping):
        """Adds a new node type / node class mapping.

        Args:
            class_type_mapping -- Dictionary with 'high_sw_id' and
                                  'node_class' keys. The factory searches
                                  its mappings in insertion order to find the
                                  Python class for a given node during init.
        """
        self.class_type_map.append(class_type_mapping)

    def get_class_for_node_type_ids(self, node_type_ids):
        """Lookup the Python class for the given node type IDs.

        Mappings are registered with add_node_type_class() by the code that
        creates the factory. User code can supplement them before calling
        init_nodes() to use custom node classes; earlier mappings take
        precedence.

        Args:
            node_type_ids: 3-tuple of integer IDs:
                 (platform_id, high_sw_id, low_sw_id)

        Returns:
            The matching node class (not an instance), or None if no mapping
            matches.
        """
        # In the current wlan_exp code the node class only depends on
        #  the application running in CPU High (AP, STA, IBSS). Future
        #  extensions may add new node classes based on platform and
        #  CPU Low application
        high_sw_id = node_type_ids[1]

        # Find the first matching class/type mapping
        for c in self.class_type_map:
            if c['high_sw_id'] == high_sw_id:
                return c['node_class']

        # No matching type-class found
        print('WARNING: no node class found for IDs platform={0}, high_sw={1}, low_sw={2}'.format(
                             node_type_ids[0], node_type_ids[1], node_type_ids[2]))

        return None