import csv
import logging
import os
import pathlib
import shutil
import sys
from io import StringIO
from typing import List, Sequence, Union
import geoalchemy2
import sqlalchemy
from sqlalchemy import MetaData, Table, func, select, text
from sqlalchemy.exc import InvalidRequestError
# ***************************** Database operations utils *****************************
[docs]
def get_table(
    table_name: str,
    schema: str,
    conn: sqlalchemy.engine.Connectable,
    logger: logging.Logger,
) -> sqlalchemy.Table:
    """Reflect a table (or view) of the database into a sqlalchemy Table object.

    Args:
        table_name (str): name of the table to look for
        schema (str): name of the schema the table is expected to be in
        conn (sqlalchemy.engine.Connectable): engine or connection used to
          reflect the table
        logger (logging.Logger): logger

    Returns:
        sqlalchemy.Table: Table object whose metadata reflects the table found
        in the database.

    Raises:
        AssertionError, NameError: if geoalchemy2 has not been imported.
        InvalidRequestError: if the table is not found in the database.
    """
    # Referencing the module by name and checking ``sys.modules`` makes sure
    # geoalchemy2 is loaded before reflecting, as the error message below
    # requires.
    try:
        geoalchemy2
        assert "geoalchemy2" in sys.modules
    except (AssertionError, NameError):
        logger.error(
            "geoalchemy2 must be imported for geometry support in table reflection."
        )
        raise

    metadata = MetaData(schema=schema)
    metadata.bind = conn
    qualified_name = f"{schema}.{table_name}"

    try:
        logger.info(f"Searching for table {qualified_name}...")
        # views=True so that views can be reflected as well as plain tables
        metadata.reflect(bind=conn, only=[table_name], views=True)
        reflected_table = Table(table_name, metadata, must_exist=True)
        logger.info(f"Table {qualified_name} found.")
    except InvalidRequestError:
        logger.error(
            f"Table {qualified_name} must exist. Make appropriate migrations "
            "and try again."
        )
        raise

    return reflected_table
[docs]
def delete(
    tables: List[sqlalchemy.Table],
    connection: sqlalchemy.engine.base.Connection,
    logger: logging.Logger,
    truncate: bool = False,
):
    """Delete all rows of the given tables.

    Useful to wipe tables before re-inserting fresh data in ETL jobs.

    Args:
        tables (List[sqlalchemy.Table]): tables to empty
        connection (sqlalchemy.engine.base.Connection): database connection
        logger (logging.Logger): logger
        truncate (bool): if ``True``, all tables are emptied with a single
          ``TRUNCATE`` statement, otherwise one ``DELETE`` statement is
          executed per table. Defaults to ``False``.
    """
    for table in tables:
        count_statement = select(func.count()).select_from(table)
        n = connection.execute(count_statement).scalar()
        logger.info(f"Table {table.name} has {n} rows.")

    if not truncate:
        for table in tables:
            logger.info(f"Deleting table {table.name}...")
            connection.execute(table.delete())
        return

    if not tables:
        # ``TRUNCATE`` without any table name is a SQL syntax error.
        return

    # Tables without a schema must not be rendered as "None"."name".
    tables_list = ", ".join(
        f'"{t.schema}"."{t.name}"' if t.schema else f'"{t.name}"' for t in tables
    )
    logger.info(f"Truncating tables {tables_list}...")
    connection.execute(text(f"TRUNCATE {tables_list}"))
[docs]
def delete_rows(
    table: sqlalchemy.Table,
    id_column: str,
    ids_to_delete: Sequence,
    connection: sqlalchemy.engine.base.Connection,
    logger: logging.Logger,
):
    """Remove from a table every row whose id is listed in ``ids_to_delete``.

    The number of rows in the table is logged before and after the deletion
    when a logger is provided.

    Args:
        table (sqlalchemy.Table): table to remove rows from
        id_column (str): name of the column in the table that contains ids to delete
        ids_to_delete (Sequence): list-like sequence of ids to look for in the
          table and delete
        connection (sqlalchemy.engine.base.Connection): database connection
        logger (logging.Logger): logger
    """
    row_count_query = select(func.count()).select_from(table)

    rows_before = connection.execute(row_count_query).fetchall()[0][0]
    if logger:
        logger.info(f"Found existing table {table.name} with {rows_before} rows.")
        logger.info(f"Deleting some rows from table {table.name}...")

    # ids are passed as strings to avoid certain type errors generated by psycopg2
    ids_as_text = [str(row_id) for row_id in ids_to_delete]
    id_filter = table.c[id_column].in_(ids_as_text)
    connection.execute(table.delete().where(id_filter))

    rows_after = connection.execute(row_count_query).fetchall()[0][0]
    if logger:
        logger.info(f"Rows after deletion: {rows_after}.")
[docs]
def psql_insert_copy(table, conn, keys, data_iter):
    """
    Execute SQL statement inserting data

    The rows are serialized to an in-memory CSV buffer which is then loaded
    with a single ``COPY ... FROM STDIN WITH CSV`` statement.

    Parameters
    ----------
    table : pandas.io.sql.SQLTable
    conn : sqlalchemy.engine.Engine or sqlalchemy.engine.Connection
    keys : list of str
        Column names
    data_iter : Iterable that iterates the values to be inserted
    """

    def quote_identifier(identifier) -> str:
        # Embedded double quotes must be doubled inside a quoted identifier,
        # otherwise a name containing '"' produces malformed SQL.
        return '"{}"'.format(str(identifier).replace('"', '""'))

    columns = ", ".join(quote_identifier(k) for k in keys)
    if table.schema:
        table_name = f"{quote_identifier(table.schema)}.{quote_identifier(table.name)}"
    else:
        table_name = quote_identifier(table.name)
    sql = "COPY {} ({}) FROM STDIN WITH CSV".format(table_name, columns)

    s_buf = StringIO()
    csv.writer(s_buf).writerows(data_iter)
    s_buf.seek(0)

    # gets a DBAPI connection that can provide a cursor
    dbapi_conn = conn.connection
    with dbapi_conn.cursor() as cur:
        cur.copy_expert(sql=sql, file=s_buf)
[docs]
def move(
    src_fp: pathlib.Path, dest_dirpath: pathlib.Path, if_exists: str = "raise"
) -> None:
    """Moves a file to another directory. If the destination directory
    does not exist, it is created, as well as all intermediate directories.

    Args:
        src_fp (pathlib.Path): path of the file to move
        dest_dirpath (pathlib.Path): directory to move the file into
        if_exists (str): what to do when ``shutil.move`` refuses to move the
          file, typically because a file with the same name already exists in
          the destination directory. ``"raise"`` re-raises the
          ``shutil.Error``, ``"replace"`` removes the existing file and moves
          the source file in its place. Defaults to ``"raise"``.

    Raises:
        shutil.Error: if the move fails and ``if_exists`` is ``"raise"``.
        ValueError: if the move fails and ``if_exists`` is neither ``"raise"``
          nor ``"replace"``.
    """
    if not dest_dirpath.exists():
        # exist_ok avoids failing if the directory gets created in the meantime
        dest_dirpath.mkdir(parents=True, exist_ok=True)

    try:
        shutil.move(src_fp.as_posix(), dest_dirpath.as_posix())
    except shutil.Error:
        if if_exists == "raise":
            raise
        if if_exists != "replace":
            raise ValueError(f"if_exists must be 'raise' or 'replace', got {if_exists}")

        existing_fp = dest_dirpath / src_fp.name
        if existing_fp.exists() and os.path.samefile(src_fp, existing_fp):
            # The file already is where it should be: removing the "existing"
            # file would delete the source itself and lose the data.
            return
        os.remove(existing_fp)
        shutil.move(src_fp.as_posix(), dest_dirpath.as_posix())
[docs]
def remove_file(fp: Union[str, pathlib.Path], ignore_errors: bool = True):
    """Removes the file located at ``fp``.

    Args:
        fp (Union[str, pathlib.Path]): path of the file to remove
        ignore_errors (bool): if ``True``, any exception raised while removing
          the file is silently ignored, otherwise it is re-raised. Defaults to
          ``True``.
    """
    try:
        os.remove(fp)
    except Exception:
        if ignore_errors:
            return
        raise