import mysql.connector
from mysql.connector import errorcode
import boto3
import sys
import json
import os
import logging
from botocore.exceptions import ClientError

# Root logger shared by every function in this module; configure_logging() sets
# its level and attaches a console handler.
logger = logging.getLogger()
# AWS service clients, created once at import time and reused by the functions below.
rds_client = boto3.client('rds')
sm_client = boto3.client('secretsmanager')
ssm_client = boto3.client('ssm')
# Deployment environment name used to build cluster ids, database names and the
# SSM parameter path; None when the variable is not set.
Environment = os.getenv('Environment')


def configure_logging(target_logger: "logging.Logger | None" = None) -> None:
    """
        Configure logging: set the logger to INFO and attach an INFO console handler.

        Safe to call repeatedly (the Lambda handler calls it on every invocation):
        the console handler is attached only once per logger, so log lines are not
        duplicated by later calls.

        Parameters:
            target_logger (logging.Logger | None): The logger to configure.
                Defaults to the module-level ``logger``.
    """
    target = logger if target_logger is None else target_logger
    target.setLevel(logging.INFO)

    # The handler is named so later calls can detect it was already attached.
    handler_name = "configure_logging_console"
    if any(existing.name == handler_name for existing in target.handlers):
        return

    console_handler = logging.StreamHandler()
    console_handler.name = handler_name
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    target.addHandler(console_handler)


def execute_sql(sql: str, connection: mysql.connector.connection) -> "list[dict] | dict | None":
    """
    Executes an SQL statement and returns the result.

    The statement is classified by its first keyword (case-insensitive, surrounding
    whitespace ignored):

    - SELECT: all rows are fetched and returned as a list of dictionaries.
    - SHOW: the first row is returned as a dictionary, or None when there are no rows.
      All rows are read so the cursor is never closed with an unread result set.
    - CREATE, DROP, INSERT, UPDATE, DELETE, ALTER, GRANT, REVOKE, TRUNCATE: the
      connection is committed and an empty dict is returned.
    - Anything else: a warning is logged, any returned rows are discarded, nothing is
      committed, and an empty dict is returned.

    Parameters:
        sql (str): The SQL statement to execute.
        connection (mysql.connector.connection): The MySQL database connection.

    Returns:
        list[dict] | dict | None: See the classification above.

    Raises:
        mysql.connector.Error: Re-raised after logging (with a specific message for
            access-denied and unknown-database errors).
        Exception: Any other error is logged together with the statement and re-raised.
    """
    try:
        statement_kind = sql.strip().upper()
        with connection.cursor(dictionary=True) as cursor:
            logger.info(f"Executing SQL: {sql}")
            cursor.execute(sql)

            if statement_kind.startswith('SELECT'):
                result = cursor.fetchall()
                logger.info(f"Returned {len(result)} Results for SQL: {sql}")
                return result

            if statement_kind.startswith('SHOW'):
                # fetchone() would leave any further rows unread (e.g. SHOW TABLES LIKE
                # matching several tables), and closing the cursor can then raise
                # "Unread result found". Read everything and keep the first row.
                rows = cursor.fetchall()
                result = rows[0] if rows else None
                logger.info(f"Returned Result for SHOW SQL: {result}")
                return result

            if statement_kind.startswith(('CREATE', 'DROP', 'INSERT', 'UPDATE', 'DELETE', 'ALTER', 'GRANT', 'REVOKE', 'TRUNCATE')):
                logger.info(f"{cursor.rowcount} record(s) affected")
                connection.commit()
                return {}

            logger.warning(f"Unhandled SQL type for statement: {sql}")
            if cursor.with_rows:
                # Drain rows from statements like DESCRIBE so closing the cursor does not fail.
                cursor.fetchall()
            return {}

    except mysql.connector.Error as err:
        if err.errno == errorcode.ER_ACCESS_DENIED_ERROR:
            logger.error("Access denied: Invalid user name or password")
        elif err.errno == errorcode.ER_BAD_DB_ERROR:
            logger.error("Database does not exist")
        else:
            logger.error(f"MySQL Error: {err}")
        raise

    except Exception as e:
        logger.error(f"Unexpected error: Could not execute SQL statement: {sql}")
        logger.error(e)
        raise


def connect_to_database(db_connection_data: dict) -> mysql.connector.connect:
    """
        Connect to the target database
        Parameters:
            db_connection_data (dict): The connection data for the target database
                db_connection_data.host = string
                db_connection_data.user = string
                db_connection_data.pass = string
                db_connection_data.name = string
        Returns:
            The open connection returned by mysql.connector.connect.

        Exits the process with status 1 if mysql.connector raises an error while connecting.
    """
    try:
        connection = mysql.connector.connect(
            host=db_connection_data['host'],
            user=db_connection_data['user'],
            password=db_connection_data['pass'],
            database=db_connection_data['name']
        )
    except mysql.connector.Error as e:
        # The connection data is keyed 'name' (there is no 'db_name'); .get keeps this
        # error path from raising a KeyError that would hide the real connection error.
        logger.error(f"ERROR: Unexpected error: Could not connect to Serverless {db_connection_data.get('name')} MySQL instance.")
        logger.error(e)
        # Non-zero status: sys.exit() with no argument would report success (0).
        sys.exit(1)
    logger.info(f"DB CONNECTION: {connection}")
    return connection


def check_if_table_exists(db_connection_data: dict, table_name: str) -> bool:
    """
        Check if a table exists in the target database
        Parameters:
            table_name (str): The exact name of the table to check for
            db_connection_data (dict): The connection data for the target database
                db_connection_data.host = string
                db_connection_data.user = string
                db_connection_data.pass = string
                db_connection_data.name = string
        Returns:
            bool: True if SHOW TABLES LIKE returns a row for table_name, False otherwise
    """
    # SHOW TABLES LIKE takes a quoted pattern, so escape the name. Backslashes go first
    # so the escapes added afterwards are not doubled. Quotes are then doubled for the
    # string literal. Finally the LIKE wildcards '%' and '_' are escaped, so that e.g.
    # 'user_data' does not also match 'userXdata'.
    like_pattern = (
        table_name.replace("\\", "\\\\\\\\")
        .replace("'", "''")
        .replace("%", "\\%")
        .replace("_", "\\_")
    )
    conn = connect_to_database(db_connection_data=db_connection_data)
    try:
        result = execute_sql(f"SHOW TABLES LIKE '{like_pattern}';", conn)
        logger.info(f"Check if table exists result: {result}")
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()
    return result is not None


def list_db_users(db_connection_data: dict) -> list:
    """
        Lists the rows of mysql.db whose user name starts with the "sl_" prefix
        Parameters:
            db_connection_data (dict): The connection data for the target database
                db_connection_data.host = string
                db_connection_data.user = string
                db_connection_data.pass = string
                db_connection_data.name = string
        Returns:
            list: One dict per matching mysql.db row (column name -> value); empty if none.

        NOTE(review): mysql.db holds database-level grants, so the same user can appear
        in several rows, and a user with no database-level grant does not appear at all.
    """
    conn = connect_to_database(db_connection_data=db_connection_data)
    try:
        # '_' is a LIKE wildcard, so it is escaped: 'sl_%' would also match e.g. 'slave'.
        results = execute_sql("SELECT * FROM mysql.db WHERE user LIKE 'sl\\_%';", conn)
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()
    if results is None:
        return []
    for row in results:
        logger.info(f"{row}")
    return results


def create_db_user(db_connection_data: dict, user_details: dict) -> None:
    """
        Creates the account 'sl_<name>'@'%' (authenticated with AWSAuthenticationPlugin)
        if it does not exist, then resets its privileges on the target database to
        exactly those of the requested role. All privileges are revoked before the
        role's privileges are granted, so this supports both downgrading and upgrading
        a user's role.

        Parameters:
            db_connection_data (dict): The connection data for the target database
                db_connection_data.host = string
                db_connection_data.user = string
                db_connection_data.pass = string
                db_connection_data.name = string
            user_details (dict): The user details, including the name and role
                user_details.name = string
                user_details.role = string (allowed values: read, write, admin, trusted_admin)

        Raises:
            ValueError: If user_details['role'] is not an allowed value. The revoke has
                already run at that point, so the user is left with no privileges.
    """

    # Define the permissions for each role
    permissions = {
        'read': 'SELECT',
        'write': 'SELECT, INSERT, UPDATE',
        'admin': 'SELECT, INSERT, UPDATE, DELETE, CREATE, REFERENCES, INDEX, ALTER, CREATE TEMPORARY TABLES, CREATE VIEW, SHOW VIEW, CREATE ROUTINE, ALTER ROUTINE, EXECUTE, EVENT, TRIGGER',
        'trusted_admin': 'ALL PRIVILEGES'
    }
    role = user_details['role']
    privileges = permissions.get(role)

    # Escape the account name for a single-quoted MySQL string literal so names
    # containing quotes or backslashes still produce valid SQL.
    account_name = f"sl_{user_details['name']}".replace("\\", "\\\\").replace("'", "''")
    account = f"'{account_name}'@'%'"
    # Backtick-quote the schema: unquoted, a name containing e.g. '-' is a syntax error.
    schema = "`" + db_connection_data['name'].replace("`", "``") + "`"

    conn = connect_to_database(db_connection_data=db_connection_data)
    try:
        create_user_result = execute_sql(f"CREATE USER IF NOT EXISTS {account} IDENTIFIED WITH AWSAuthenticationPlugin as 'RDS';", conn)
        logger.info(f"Create User Result: {create_user_result}")
        revoke_permissions_result = execute_sql(f"REVOKE ALL PRIVILEGES, GRANT OPTION FROM {account};", conn)
        logger.info(f"Revoke Permissions Result: {revoke_permissions_result}")
        if privileges is None:
            # Fail with a clear error instead of sending "GRANT None ON ...".
            raise ValueError(
                f"Unknown role {role!r} for user {user_details['name']!r}; "
                f"expected one of: {', '.join(permissions)}"
            )
        grant_permissions_result = execute_sql(f"GRANT {privileges} ON {schema}.* TO {account};", conn)
        logger.info(f"Grant Permissions Result: {grant_permissions_result}")
    finally:
        # Close even when a statement fails so the connection is not leaked.
        conn.close()


def delete_db_users(db_connection_data: dict, user_name: str) -> None:
    """
        Drops the account user_name@'%' from the target database; if it doesn't exist, it's fine

        Parameters:
            db_connection_data (dict): The connection data for the target database
                (keys: host, user, pass, name)
            user_name (str): The full account name, including any prefix such as "sl_"
    """
    logger.info(f"About to Connect to {db_connection_data['name']} in order to delete {user_name}")
    # Escape for a single-quoted MySQL string literal so quotes in the name stay valid SQL.
    quoted_user_name = user_name.replace("\\", "\\\\").replace("'", "''")
    conn = connect_to_database(db_connection_data=db_connection_data)
    try:
        drop_user_result = execute_sql(f"DROP USER IF EXISTS '{quoted_user_name}'@'%';", conn)
        logger.info(f"Drop User Result: {drop_user_result}")
    finally:
        # Close even when the DROP fails so the connection is not leaked.
        conn.close()


def get_target_secret_arn(db: str) -> str:
    """
        Get the ARN of the master user secret for the target database cluster

        Parameters:
            db (str): Short database name; the cluster id is
                'esim-<Environment>-<db>-db-cluster-v3'

        Returns:
            str: The SecretArn of the cluster's MasterUserSecret

        Exits the process with status 1 if the cluster cannot be described or
        has no MasterUserSecret/SecretArn.
    """
    logger.info(f'Getting secret arn for {db}')
    try:
        response = rds_client.describe_db_clusters(
            DBClusterIdentifier=f'esim-{Environment}-{db}-db-cluster-v3'
        )
        cluster = response['DBClusters'][0]
        # Indexing (not .get) so a missing secret is reported here instead of
        # returning None and failing later in the secret lookup.
        secret_arn = cluster['MasterUserSecret']['SecretArn']
        logger.info(f"{db} Secret ARN: {secret_arn}")
        return secret_arn

    except Exception as e:
        logger.error(f"ERROR: Unexpected error: Could not retrieve secret id for {db}.")
        logger.error(e)
        # Non-zero status: sys.exit() with no argument would report success (0).
        sys.exit(1)


def get_target_db_connection_data(db: str) -> dict:
    """
        Build the connection data for one managed database

        Parameters:
            db (str): Short database name, e.g. 'db_type_1'

        Returns:
            dict: {'host': cluster endpoint, 'user': 'admin',
                   'pass': the 'password' field of the cluster's master user secret,
                   'name': '<Environment>_<db>'}
    """
    logger.info(f'Getting target db connection data: {db}')
    cluster_id = f'esim-{Environment}-{db}-db-cluster-v3'

    # Calls are made in this order: secret ARN, endpoint, then the password itself.
    secret_arn = get_target_secret_arn(db)
    endpoint = get_database_endpoint(cluster_id)
    password = get_secret_value(str(secret_arn), 'password')

    connection_data = dict(
        host=endpoint,
        user='admin',
        name=f"{Environment}_{db}",
    )
    connection_data['pass'] = password
    return {key: connection_data[key] for key in ('host', 'user', 'pass', 'name')}


def get_database_endpoint(db_identifier: str) -> str:
    """
        Get the endpoint of an RDS DB cluster

        Parameters:
            db_identifier (str): The DB cluster identifier

        Returns:
            str: The 'Endpoint' of the first cluster returned by describe_db_clusters

        Exits the process with status 1 if the cluster cannot be described.
    """
    try:
        response = rds_client.describe_db_clusters(
            DBClusterIdentifier=db_identifier
        )
        return response['DBClusters'][0]['Endpoint']

    except Exception as e:
        logger.error(f"ERROR: Unexpected error: Could not retrieve endpoint for {db_identifier}.")
        logger.error(e)
        # Non-zero status: sys.exit() with no argument would report success (0).
        sys.exit(1)


def get_secret_value(secret_arn: str, json_key: str) -> str:
    """
        Get a secret value from AWS Secrets Manager

        The secret's SecretString is parsed as JSON and the value under json_key is returned.

        parameters:
            secret_arn (str): The ARN of the secret in Secrets Manager
            json_key (str): The key in the JSON object to retrieve

        returns:
            str: The value of the key in the JSON object

        Exits the process with status 1 if the secret cannot be read, has no JSON
        SecretString, or does not contain json_key.
    """
    logger.info(f'Getting secret value for {secret_arn}, key: {json_key}')

    try:
        response = sm_client.get_secret_value(
            SecretId=secret_arn
        )
        secret_json = json.loads(response.get('SecretString'))
        return secret_json[json_key]
    except Exception as e:
        logger.error(f"ERROR: Unexpected error: Could not retrieve secret {secret_arn}. Maybe the key {json_key} does not exist.")
        logger.error(e)
        # Non-zero status: sys.exit() with no argument would report success (0).
        sys.exit(1)


def get_ssm_value(parameter_name: str) -> dict:
    """
        Read a Parameter Store parameter (with decryption) and decode its value as JSON

        parameters:
            parameter_name (str): Name of the parameter to read

        returns:
            dict: The JSON-decoded parameter value, or {} when the SSM call or the
                  decoding fails (the failure is logged)
    """
    try:
        param = ssm_client.get_parameter(Name=parameter_name, WithDecryption=True).get('Parameter')
        # NOTE(review): WithDecryption=True means a SecureString value is logged
        # in clear text at INFO by the next two lines.
        logger.info(f"Got SSM Parameter: {param}")
        raw_value = param.get('Value')
        logger.info(f"Got SSM Value: {raw_value}")
        decoded = json.loads(raw_value)
        logger.info(f"JSON LOADS WORKED: {decoded}")
    except ClientError as client_err:
        logger.error(f"Client Error caught retrieving parameter {parameter_name}")
        logger.error(client_err)
        return {}
    except Exception as unexpected:
        logger.error(f"General Exception caught getting: {parameter_name}")
        logger.error(unexpected)
        return {}
    return decoded


def handler(event, context):
    """
        Lambda handler function

        - Gets the list of database users from SSM (stops if it could not be loaded)
        - Loops through each database
            - Gets the target db connection data (querying RDS for endpoint and secrets manager arn, then gets the secret value)
            - Gets the list of existing "sl_" users in the database
            - Drops users that aren't in config
            - Goes through the users in config and ensures they're created (removes permissions for each and then reapplies permissions)

        Parameters:
            event: The Lambda event; only logged.
            context: The Lambda context; only logged.

        Raises:
            RuntimeError: If the SSM configuration does not contain a "users" list.
    """

    # Make sure we're logging correctly
    configure_logging()

    # For debugging
    logger.info(f"Context: {context}")
    logger.info(f"Event: {event}")

    # Get the configuration users from SSM
    db_users = get_ssm_value(f'/esim-app/{Environment}/database/users')
    logger.info(f"DB users from configuration: {db_users}")

    configured_users = db_users.get('users') if isinstance(db_users, dict) else None
    if not isinstance(configured_users, list):
        # get_ssm_value returns {} on failure. Without the configured list we cannot
        # tell which accounts are stale, so stop before touching any database.
        raise RuntimeError(
            f"No 'users' list in SSM parameter /esim-app/{Environment}/database/users"
        )

    # Built once so each existing user is checked with a set lookup instead of a scan.
    configured_names = {user.get("name") for user in configured_users}

    # Define the list of databases we want to manage
    databases = ['db_type_1', 'db_type_2', 'db_type_3', 'db_type_4', 'db_type_5']

    # Loop through each database
    for db in databases:
        # Get the target db connection data
        target_db_connection_data = get_target_db_connection_data(db)

        # Existing users (we may want to delete some). A user can appear in several
        # mysql.db rows, so de-duplicate while keeping first-seen order.
        existing_rows = list_db_users(target_db_connection_data) or []
        existing_account_names = dict.fromkeys(row.get('User') for row in existing_rows)

        # Find which users aren't in config and go delete them
        for account_name in existing_account_names:
            # Only manage accounts created by this function (those prefixed "sl_").
            if not account_name or not account_name.startswith("sl_"):
                logger.info(f"Skipping {account_name}: not an sl_ user")
                continue

            user_name = account_name.removeprefix("sl_")

            if user_name in configured_names:
                logger.info(f"sl_{user_name} is in config, leave them be")
            else:
                logger.info(f"sl_{user_name} isn't in config, I should delete them here")
                delete_db_users(target_db_connection_data, f"sl_{user_name}")

        # Go through the users in config and ensure they're created
        for user in configured_users:
            logger.info(f"Creating: {user}")
            create_db_user(target_db_connection_data, user)


if __name__ == '__main__':
    # Local entry point: there is no Lambda event or context here, so both are None.
    handler(event=None, context=None)
