From 89f862ec266ed07ce29828e4b4d43241d37c10d5 Mon Sep 17 00:00:00 2001
From: Marko Mecina <marko.mecina@univie.ac.at>
Date: Thu, 17 Nov 2022 11:57:39 +0100
Subject: [PATCH] update load_to_memory function

---
 Ccs/ccs_function_lib.py | 99 +++++++++++++++++++++++++++++------------
 1 file changed, 70 insertions(+), 29 deletions(-)

diff --git a/Ccs/ccs_function_lib.py b/Ccs/ccs_function_lib.py
index 912a225..1644b00 100644
--- a/Ccs/ccs_function_lib.py
+++ b/Ccs/ccs_function_lib.py
@@ -2651,48 +2651,89 @@ def calc_param_crc(cmd, *args, no_check=False, hack_value=None):
     return crc(pdata[:-PEC_LEN])
 
 
-def tc_load_to_memory(data, memid, mempos, slicesize=1000, sleep=0., ack=None, pool_name='LIVE'):
+def load_to_memory(data, memid, memaddr, max_pkt_size=1000, sleep=0.125, ack=0b1001, pool_name='LIVE', tcname=None,
+                   progress=True, calc_crc=True, align=4):
     """
-    Function for loading large data to DPU memory. Splits the input _data_ into slices and sequentially sends them
-    to the specified location _memid_, _mempos_ by repeatedly calling the _Tcsend_DB_ function until
-    all _data_ is transferred.
-
-    :param data:  Data to be sent to memory. Can be a path to a file or bytestring or struct object
-    :param memid: Memory that data is sent to (e.g. 'DPU_RAM')
-    :param mempos: Memory start address the data should be patched to
-    :param slicesize: Size in bytes of the individual data slices, max=1000
-    :param sleep: Idle time in seconds between sending the individual TC packets
-    :param ack: Override the I-DB TC acknowledment value (4-bit binary string, e.g., '0b1011')
-    :param pool_name: connection through which to send the data
-    :return:
+    Function for loading data to DPU memory. Splits the input _data_ into slices and sequentially sends them
+    to the specified location _memid_, _mempos_ by repeatedly calling the _Tcsend_bytes_ function until
+    all _data_ is transferred. Data is zero-padded if not aligned to _align_ bytes.
+    @param data:
+    @param memid:
+    @param memaddr:
+    @param max_pkt_size:
+    @param sleep:
+    @param ack:
+    @param pool_name:
+    @param tcname:
+    @param progress:
+    @param calc_crc:
+    @param align:
+    @return:
     """
+
     if not isinstance(data, bytes):
         if isinstance(data, str):
             data = open(data, 'rb').read()
         else:
-            raise TypeError
+            raise TypeError('Data is not bytes or str')
 
-    cmd = get_tc_descr_from_stsst(6, 2)[0]
+    if align and (len(data) % align):
+        logger.warning('Data is not {}-byte aligned, padding.'.format(align))
+        data += bytes(align - (len(data) % align))
 
-    slices = [data[i:i + slicesize] for i in range(0, len(data), slicesize)]
-    if slicesize > (MAX_PKT_LEN - TC_HEADER_LEN - PEC_LEN):
-        logger.warning('SLICESIZE > {} bytes, this is not gonna work!'.format(MAX_PKT_LEN - TC_HEADER_LEN - PEC_LEN))
-    slicount = 1
+    # get service 6,2 info from MIB
+    apid, memid_ref, fmt, endspares = _get_upload_service_info(tcname)
+    pkt_overhead = TC_HEADER_LEN + struct.calcsize(fmt) + len(endspares) + PEC_LEN
+    payload_len = max_pkt_size - pkt_overhead
+
+    memid = get_mem_id(memid, memid_ref)
+
+    # get permanent pmgr handle to avoid requesting one for each packet
+    pmgr = _get_pmgr_handle(tc_pool=pool_name)
+
+    data_size = len(data)
+    startaddr = memaddr
+
+    upload_bytes = b''
+    pcnt = 0
+    ptot = None
+
+    slices = [data[i:i + payload_len] for i in range(0, len(data), payload_len)]
+    if (payload_len + pkt_overhead) > MAX_PKT_LEN:
+        logger.warning('PKTSIZE > {} bytes, this might not work!'.format(MAX_PKT_LEN))
 
     for sli in slices:
         t1 = time.time()
-        #parts = struct.unpack(len(sli) * 'B', sli)
-        #parts = *sli
-        Tcsend_DB(cmd, memid, mempos, len(sli), *sli, ack=ack, pool_name=pool_name)
-        # sys.stdout.write('%i / %i packets sent to {}\r'.format(slicount, len(slices), memid))
-        logger.info('%i / %i packets sent to {}'.format(slicount, len(slices), memid))
-        slicount += 1
-        mempos += len(sli)
+
+        # create PUS packet
+        packetdata = struct.pack(fmt, memid, startaddr, len(sli)) + sli + endspares
+        seq_cnt = counters.setdefault(apid, 0)
+        puspckt = Tcpack(data=packetdata, st=6, sst=2, apid=apid, sc=seq_cnt, ack=ack)
+
+        if len(puspckt) > MAX_PKT_LEN:
+            logger.warning('Packet length ({}) exceeding MAX_PKT_LEN of {} bytes!'.format(len(puspckt), MAX_PKT_LEN))
+
+        Tcsend_bytes(puspckt, pool_name=pool_name, pmgr_handle=pmgr)
+        # collect all uploaded segments for CRC at the end
+        upload_bytes += sli
+        pcnt += 1
+
+        if progress:
+            if ptot is None:
+                ptot = int(np.ceil(data_size / len(sli)))  # packets needed to transfer data
+            print('{}/{} packets sent\r'.format(pcnt, ptot), end='')
+
         dt = time.time() - t1
-        if dt < sleep:
-            time.sleep(sleep - dt)
+        time.sleep(max(sleep - dt, 0))
+
+        startaddr += len(sli)
+        counters[apid] += 1
 
-    return len(data)
+    print('\nUpload finished, {} bytes sent in {} packets.'.format(len(upload_bytes), pcnt))
+
+    if calc_crc:
+        # return total length of uploaded data  and CRC over entire uploaded data
+        return len(upload_bytes), crc(upload_bytes)
 
 
 def get_tc_descr_from_stsst(st, sst):
-- 
GitLab