From e0c40bf187d9e0ee0b604e9aa69a581a15343390 Mon Sep 17 00:00:00 2001
From: Marko Mecina <marko.mecina@univie.ac.at>
Date: Fri, 21 Oct 2022 11:24:27 +0200
Subject: [PATCH] migrate ce_decompress and associated functions from cheops
 CCS

---
 Ccs/ccs_function_lib.py | 313 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 313 insertions(+)

diff --git a/Ccs/ccs_function_lib.py b/Ccs/ccs_function_lib.py
index 093ef80..a816430 100644
--- a/Ccs/ccs_function_lib.py
+++ b/Ccs/ccs_function_lib.py
@@ -107,6 +107,12 @@ else:
 
 # Notify.init('cfl')
 
+# Module-level state for CE and S13 (large data transfer) collection
+ldt_minimum_ce_gap = 0.001    # minimum time gap [s] assumed between two CEs
+ce_collect_timeout = 1        # polling interval [s] of the CE collection worker
+last_ce_time = 0              # CUC time of the last collected CE
+ce_decompression_on = False   # flag that keeps the decompression worker running
+
 
 def _add_log_socket_handler():
     global logger
@@ -3983,6 +3989,313 @@ def interleave_lists(*args):
     return [i for j in zip(*args) for i in j]
 
 
+def create_format_model():
+    store = Gtk.ListStore(str)
+    for fmt in fmtlist.keys():
+        if fmt != 'bit*':
+            store.append([fmt])
+    for pers in personal_fmtlist:
+        store.append([pers])
+    return store
+
+
+# #################### S13 data collection and CE (compression entity) decompression, migrated from the old CHEOPS CCS ####################
+import threading
+
+import astropy.io.fits as pyfits
+
+CEthread = None  # handle of the CE decompression worker thread
+IFSWDIR = cfg.get('paths', 'obsw')  # OBSW tree providing the DecompressCe executable
+if not IFSWDIR.endswith('/'):
+    IFSWDIR += '/'
+
+
+def collect_13(pool_name, starttime=0, endtime=None, start=None, end=None, join=True, collect_all=False,
+               sdu=None):
+    """
+    Collect the data fields of completed TM(13) large data transfers, i.e. (13,1), (13,2)*, (13,3) sequences.
+
+    @param pool_name: name of the TM pool to search
+    @param starttime: CUC time at which the search starts; overridden by *start*
+    @param endtime: CUC time at which the search ends; overridden by *end*; defaults to the time of the last packet in the pool
+    @param start: pool index of the first packet to consider
+    @param end: pool index of the last packet to consider
+    @param join: if True, return each transfer as a single bytestring, else as a list of packet data fields
+    @param collect_all: if True, collect all transfers in the interval, else only the first one
+    @param sdu: if given, only consider packets whose first source data byte equals this SDU ID
+    @return: {CUC time of first packet: assembled transfer data}; {None: None} if no complete transfer is found
+    """
+    rows = get_pool_rows(pool_name)
+    if start is not None:
+        starttime = float(rows.filter(DbTelemetry.idx == start).first().timestamp[:-1])
+
+    if endtime is None:
+        endtime = get_last_pckt_time(pool_name, string=False)
+
+    if end is not None:
+        endtime = float(rows.filter(DbTelemetry.idx == end).first().timestamp[:-1])
+
+    if starttime is None or endtime is None:
+        raise ValueError('Specify start(time) and end(time)!')
+
+    ces = {}
+    # faster method to collect TM13 transfers already completed
+    tm_bounds = rows.filter(DbTelemetry.stc == 13, DbTelemetry.sst.in_([1, 3]),
+                            func.left(DbTelemetry.timestamp, func.length(DbTelemetry.timestamp) - 1) >= starttime,
+                            func.left(DbTelemetry.timestamp, func.length(DbTelemetry.timestamp) - 1) <= endtime
+                            ).order_by(DbTelemetry.idx)
+    if sdu:
+        tm_bounds = tm_bounds.filter(func.left(DbTelemetry.data, 1) == struct.pack('>B', sdu))
+
+    # bail out if there are not at least a start and an end packet
+    if tm_bounds.count() < 2:
+        return {None: None}
+
+    tm_132 = rows.filter(DbTelemetry.stc == 13, DbTelemetry.sst == 2,
+                         func.left(DbTelemetry.timestamp, func.length(DbTelemetry.timestamp) - 1) > starttime,
+                         func.left(DbTelemetry.timestamp, func.length(DbTelemetry.timestamp) - 1) < endtime
+                         ).order_by(DbTelemetry.idx)
+    if sdu:
+        tm_132 = tm_132.filter(func.left(DbTelemetry.data, 1) == struct.pack('>B', sdu))
+
+    # make sure to start with a 13,1
+    while tm_bounds[0].sst != 1:
+        tm_bounds = tm_bounds[1:]
+        if len(tm_bounds) < 2:
+            return {None: None}
+    # and end with a 13,3
+    while tm_bounds[-1].sst != 3:
+        tm_bounds = tm_bounds[:-1]
+        if len(tm_bounds) < 2:
+            return {None: None}
+
+    if not collect_all:
+        tm_bounds = tm_bounds[:2]
+    else:
+        tm_bounds = tm_bounds[:]  # cast Query to list if not list already
+
+    # check for out of order 1s and 3s
+    idcs = [i.sst for i in tm_bounds]
+    outoforder = []
+    for i in range(len(idcs) - 1):
+        if idcs[i + 1] == idcs[i]:
+            if idcs[i] == 1:
+                outoforder.append(i)
+            elif idcs[i] == 3:
+                outoforder.append(i + 1)
+    dels = 0
+    for i in outoforder:
+        del tm_bounds[i - dels]
+        dels += 1
+    # check if start/end (1/3) strictly alternating
+    if not np.diff([i.sst for i in tm_bounds]).all():
+        logger.warning('Detected inconsistent transfers')
+        return {None: None}
+
+    k = 0
+    n = len(tm_bounds)
+    while k < n - 1:
+        i, j = tm_bounds[k:k + 2]
+        if i.sst == j.sst:
+            k += 1
+            continue
+        pkts = [a.raw[21:-2] for a in tm_132.filter(DbTelemetry.idx > i.idx, DbTelemetry.idx < j.idx)]
+        if join:
+            ces[float(i.timestamp[:-1])] = i.raw[21:-2] + b''.join(pkts) + j.raw[21:-2]
+        else:
+            ces[float(i.timestamp[:-1])] = [i.raw[21:-2]] + [b''.join(pkts)] + [j.raw[21:-2]]
+        k += 2
+
+    return ces
+
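+# Usage sketch for collect_13 (illustrative; the pool name and time range are
+# assumptions, not call sites introduced by this patch):
+#
+#     ces = collect_13('LIVE', starttime=0, collect_all=True, join=True)
+#     for cuc, buf in ces.items():
+#         if buf is not None:
+#             print(cuc, len(buf))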
+
+
+def dump_large_data(pool_name, starttime=0, endtime=None, outdir="", dump_all=False, sdu=None):
+    """
+    Dump the assembled content of TM(13) large data transfers to disk, one CE file per transfer.
+
+    @param pool_name: name of the TM pool to search
+    @param starttime: CUC time at which the search starts
+    @param endtime: CUC time at which the search ends
+    @param outdir: directory the CE files are written to
+    @param dump_all: if True, dump all transfers in the interval, else only the first one
+    @param sdu: if given, only consider packets of this SDU
+    @return: dict of {CUC time of first packet: name of the written CE file}
+    """
+    filedict = {}
+    ldt_dict = collect_13(pool_name, starttime=starttime, endtime=endtime, join=True, collect_all=dump_all,
+                          sdu=sdu)
+    for buf in ldt_dict:
+        if ldt_dict[buf] is None:
+            continue
+        obsid, timestamp, ftime, ctr = struct.unpack('>IIHH', ldt_dict[buf][:12])
+        fname = os.path.join(outdir, "LDT_{:03d}_{:010d}_{:06d}.ce".format(obsid, timestamp, ctr))
+        with open(fname, "wb") as fdesc:
+            fdesc.write(ldt_dict[buf])
+            filedict[buf] = fdesc.name
+    return filedict
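+
+# Usage sketch for dump_large_data (illustrative; the pool name and output directory are assumptions):
+#
+#     files = dump_large_data('LIVE', starttime=0, outdir='/tmp/ces', dump_all=True)
+#     # files -> {cuc_time: '/tmp/ces/LDT_<obsid>_<time>_<ctr>.ce', ...}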
+
+
+def create_fits(data=None, header=None, filename=None):
+    """
+    Create a FITS HDU list with a primary HDU and empty imagette, stack and margin image extensions.
+
+    @param data: currently unused
+    @param header: header for the primary HDU
+    @param filename: if given, the HDU list is also written to this file
+    @return: the assembled HDUList
+    """
+    hdulist = pyfits.HDUList()
+    hdu = pyfits.PrimaryHDU(header=header)
+    hdulist.append(hdu)
+
+    imagette_hdu = pyfits.ImageHDU()
+    stack_hdu = pyfits.ImageHDU()
+    margins = pyfits.ImageHDU()
+
+    hdulist.append(imagette_hdu)
+    hdulist.append(stack_hdu)
+    hdulist.append(margins)
+
+    if filename:
+        with open(filename, "wb") as fd:
+            hdulist.writeto(fd)
+
+    return hdulist
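+
+# Usage sketch for create_fits (illustrative; the file name is an assumption):
+#
+#     hdul = create_fits(header=pyfits.Header(), filename='/tmp/skeleton.fits')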
+
+
+def ce_decompress(pool_name='LIVE', outdir="", sdu=None, starttime=None, endtime=None):
+    """
+    Start a background thread that collects CEs from *pool_name* and decompresses them to FITS files in *outdir*.
+
+    @param pool_name: name of the TM pool to search
+    @param outdir: directory the CE and FITS files are written to
+    @param sdu: if given, only consider packets of this SDU
+    @param starttime: CUC time at which the collection starts
+    @param endtime: CUC time up to which already completed transfers are collected
+    @return: the started worker thread
+    """
+    global ce_decompression_on
+    global CEthread
+    global last_ce_time
+
+    if outdir and not os.path.isdir(outdir):
+        os.makedirs(outdir)
+
+    thread = threading.Thread(target=_ce_decompress_worker,
+                              kwargs={'pool_name': pool_name, 'outdir': outdir, 'sdu': sdu, 'endtime': endtime},
+                              name="CeDecompression")
+    thread.daemon = True
+    CEthread = thread
+    if starttime is not None:
+        last_ce_time = starttime
+    ce_decompression_on = True
+    thread.start()
+    return thread
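+
+# Usage sketch for the decompression thread (illustrative; the pool name and
+# output directory are assumptions):
+#
+#     ce_decompress('LIVE', outdir='/tmp/ces')   # start background collection
+#     ...                                        # CEs are decompressed as they arrive
+#     stop_ce_decompress()                       # terminate the worker loop
+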
+
+def _ce_decompress_worker(pool_name, outdir="", sdu=None, endtime=None):
+    global ce_collect_timeout
+    global ldt_minimum_ce_gap
+    global last_ce_time
+    global ce_decompression_on
+    global IFSWDIR
+
+    def decompress(cefile):
+        logger.info("Decompressing {}".format(cefile))
+        fitspath = cefile[:-2] + 'fits'
+        if os.path.isfile(fitspath):
+            os.remove(fitspath)
+        # log the DecompressCe output next to the CE file
+        with open(cefile[:-2] + 'log', 'w') as logfd:
+            subprocess.run([IFSWDIR + "CompressionEntity/build/DecompressCe", cefile, fitspath], stdout=logfd)
+
+    # first, dump and decompress all transfers already completed in the pool
+    filedict = dump_large_data(pool_name=pool_name, starttime=last_ce_time, endtime=endtime, outdir=outdir,
+                               dump_all=True, sdu=sdu)
+    for ce in filedict:
+        last_ce_time = ce
+        decompress(filedict[ce])
+    if filedict:
+        # step past the last collected CE, so the polling loop below does not collect it again
+        last_ce_time += ldt_minimum_ce_gap
+
+    while ce_decompression_on:
+        filedict = dump_large_data(pool_name=pool_name, starttime=last_ce_time, endtime=None, outdir=outdir,
+                                   dump_all=False, sdu=sdu)
+        if len(filedict) == 0:
+            time.sleep(ce_collect_timeout)
+            continue
+        last_ce_time, cefile = list(filedict.items())[0]
+        decompress(cefile)
+        last_ce_time += ldt_minimum_ce_gap
+        time.sleep(ce_collect_timeout)
+
+
+def stop_ce_decompress():
+    """
+    Signal the CE decompression worker thread to terminate.
+    """
+    global ce_decompression_on
+    ce_decompression_on = False
+
+
+def reset_ce_decompress(timestamp=0.0):
+    """
+    Reset the CE collection start time to *timestamp*.
+
+    @param timestamp: CUC time from which the next collection starts
+    """
+    global last_ce_time
+    last_ce_time = timestamp
+
+
+def build_fits(basefits, newfits):
+    """
+    Append the HDU data arrays of *newfits* to the corresponding HDUs of *basefits* and overwrite *basefits* with the result.
+
+    @param basefits: FITS file to extend (rewritten in place)
+    @param newfits: FITS file providing the data to append
+    """
+    base = pyfits.open(basefits)
+    new = pyfits.open(newfits)
+    for i in range(len(base)):
+        base[i].data = np.concatenate([base[i].data, new[i].data])
+    new.close()
+    base.writeto(basefits, overwrite=True)
+    base.close()
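+
+# Usage sketch for build_fits (illustrative; the file names are assumptions):
+# incrementally grow a cube by appending each newly decompressed CE
+#
+#     build_fits('/tmp/ces/cube.fits', '/tmp/ces/LDT_001_0000000123_000001.fits')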
+
+
+def convert_fullframe_to_cheopssim(fname):
+    """
+    Convert a full-frame (1076x1033 pixels per frame) FITS cube to the CHEOPS-SIM window format.
+
+    @param fname: input FITS file; the converted file is written alongside it with a _CHEOPSSIM.fits suffix
+    """
+    d = pyfits.open(fname)
+    full = np.array(np.round(d[0].data), dtype=np.uint16)
+    # split each frame into the science subarray and the margin windows
+    win_dict = {"SubArray": full[:, :1024, 28:28+1024],
+                "OverscanLeftImage": full[:, :1024, :4],
+                "BlankLeftImage": full[:, :1024, 4:4+8],
+                "DarkLeftImage": full[:, :1024, 12:28],
+                "DarkRightImage": full[:, :1024, 1052:1052+16],
+                "BlankRightImage": full[:, :1024, 1068:],
+                "DarkTopImage": full[:, 1024:-6, 28:-24],
+                "OverscanTopImage": full[:, -6:, 28:-24]}
+
+    hdulist = pyfits.HDUList()
+    hdulist.append(pyfits.PrimaryHDU())
+
+    for win in win_dict:
+        hdu = pyfits.ImageHDU(data=win_dict[win], name=win)
+        hdulist.append(hdu)
+
+    hdulist.append(pyfits.BinTableHDU(name="ImageMetaData"))
+
+    hdulist.writeto(fname[:-5] + '_CHEOPSSIM.fits')
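+
+# Usage sketch (illustrative; the file name is an assumption):
+#
+#     convert_fullframe_to_cheopssim('/path/to/fullframe.fits')
+#     # -> writes /path/to/fullframe_CHEOPSSIM.fits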
+
+# #################### end of functions migrated from the old CHEOPS CCS ####################
+
+
 class TestReport:
 
     def __init__(self, filename, version, idb_version, gui=False, delimiter='|'):
-- 
GitLab