From 87825e7ea4e0a5b109e537363f16bb5b5e470e73 Mon Sep 17 00:00:00 2001
From: Marko Mecina <marko.mecina@univie.ac.at>
Date: Wed, 9 Aug 2023 16:23:28 +0200
Subject: [PATCH] compatibility fixes for SQLAlchemy 2.0

---
 Ccs/database/tm_db.py   | 23 ++++++++++++++++-------
 Ccs/tools/import_mib.py |  7 ++++---
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/Ccs/database/tm_db.py b/Ccs/database/tm_db.py
index 8d14fe4..0720f73 100644
--- a/Ccs/database/tm_db.py
+++ b/Ccs/database/tm_db.py
@@ -10,6 +10,7 @@ from sqlalchemy import (
     Column, Integer, Boolean, Unicode, Index, UniqueConstraint, ForeignKey, create_engine, engine)
 from sqlalchemy.dialects.mysql import VARBINARY
 from sqlalchemy.orm import sessionmaker, relationship, scoped_session
+from sqlalchemy.sql import text
 # from sqlalchemy.orm.session import Session
 
 # Use for SQlite
@@ -232,8 +233,6 @@ class FEEDataTelemetryPool(FEEDATA_BASE):  # type: ignore
 def gen_mysql_conn_str(user=config_db.user, pw=config_db.pw, host=config_db.host, schema=''):
     return engine.url.URL.create(drivername='mysql', username=user, password=pw, host=host, database=schema)
 
-#SQLSOCKET=''
-
 
 def create_storage_db(protocol='PUS', force=False):
     if protocol.upper() not in ['PUS', 'RMAP', 'FEEDATA', 'ALL']:
@@ -243,8 +242,8 @@ def create_storage_db(protocol='PUS', force=False):
         print('Creating schema "{}" for {} data storage...'.format(config_db.storage_schema_name, protocol.upper()))
         _engine = create_engine(gen_mysql_conn_str(), echo="-v" in sys.argv)
         if force:
-            _engine.execute('DROP SCHEMA IF EXISTS {}'.format(config_db.storage_schema_name))
-        _engine.execute('CREATE SCHEMA IF NOT EXISTS {}'.format(config_db.storage_schema_name))
+            _engine.execute(text('DROP SCHEMA IF EXISTS {}'.format(config_db.storage_schema_name)))
+        _engine.execute(text('CREATE SCHEMA IF NOT EXISTS {}'.format(config_db.storage_schema_name)))
         _engine.dispose()
         _engine = create_engine(gen_mysql_conn_str(schema=config_db.storage_schema_name), echo="-v" in sys.argv)
         for protocol in protocols:
@@ -254,8 +253,8 @@ def create_storage_db(protocol='PUS', force=False):
         print('Creating schema "{}" for {} data storage...'.format(config_db.storage_schema_name, protocol.upper()))
         _engine = create_engine(gen_mysql_conn_str(), echo="-v" in sys.argv)
         if force:
-            _engine.execute('DROP SCHEMA IF EXISTS {}'.format(config_db.storage_schema_name))
-        _engine.execute('CREATE SCHEMA IF NOT EXISTS {}'.format(config_db.storage_schema_name))
+            _engine.execute(text('DROP SCHEMA IF EXISTS {}'.format(config_db.storage_schema_name)))
+        _engine.execute(text('CREATE SCHEMA IF NOT EXISTS {}'.format(config_db.storage_schema_name)))
         _engine.dispose()
         _engine = create_engine(gen_mysql_conn_str(schema=config_db.storage_schema_name), echo="-v" in sys.argv)
         protocols[protocol.upper()][1].metadata.create_all(_engine)
@@ -280,10 +279,20 @@ def scoped_session_maker(db_schema, idb_version=None):
         return
     _engine = create_engine(gen_mysql_conn_str(schema=schema), echo="-v" in sys.argv, pool_size=15)
     session_factory = sessionmaker(bind=_engine)
-    scoped_session_factory = scoped_session(session_factory)
+    # scoped_session_factory = scoped_session(session_factory)
+    scoped_session_factory = scoped_session_v2(session_factory)
     return scoped_session_factory
 
 
+class scoped_session_v2(scoped_session):
+    """
+    Wrapper class to cast SQL query statement string to TextClause before execution, as this is required since SQLAlchemy 2.0.
+    """
+
+    def execute(self, x, *args, **kwargs):
+        return super().execute(text(x), *args, **kwargs)
+
+
 # def load_telemetry_file(dummy: str) -> None:
 #     '''Loads a telemetry dumpfile in the database, populating
 #     the tm_pool and tm tables.
diff --git a/Ccs/tools/import_mib.py b/Ccs/tools/import_mib.py
index b1514bd..ba316d9 100755
--- a/Ccs/tools/import_mib.py
+++ b/Ccs/tools/import_mib.py
@@ -11,6 +11,7 @@ import getpass
 
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
+from sqlalchemy.sql import text
 
 
 sdir = os.path.dirname(os.path.abspath(__file__))
@@ -62,11 +63,11 @@ def create_schema():
 
     # delete database schema
     print('...drop schema {}'.format(DBNAME))
-    s.execute('DROP SCHEMA IF EXISTS {}'.format(DBNAME))
+    s.execute(text('DROP SCHEMA IF EXISTS {}'.format(DBNAME)))
 
     # create database schema
     print('...create schema {}'.format(DBNAME))
-    s.execute(open(WBSQL).read())
+    s.execute(text(open(WBSQL).read()))
     s.close()
 
 
@@ -86,7 +87,7 @@ def import_mib():
         rows = [('"' + i.replace('\t', '","').strip() + '"').replace('""', 'DEFAULT') for i in mfile]
         try:
             for row in rows:
-                s.execute('INSERT IGNORE INTO {} VALUES ({})'.format(fn[:-4], row))  # IGNORE truncates too long strings
+                s.execute(text('INSERT IGNORE INTO {} VALUES ({})'.format(fn[:-4], row)))  # IGNORE truncates too long strings
         except Exception as err:
             s.rollback()
             s.close()
-- 
GitLab