#!/usr/bin/env python3

"""
PostgreSQL Health Monitoring Script for RateRight
Comprehensive enterprise-grade database health checks
Run this script regularly to monitor database health
"""

import json
import os
import sys
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Tuple

import psycopg2

# Configuration
# Connection string for the monitored database, read from the environment
# once at import time (None when the variable is unset).
DATABASE_URL = os.environ.get('DATABASE_URL')

# Alert levels used by the DatabaseHealthMonitor checks. Each key in
# CRITICAL_THRESHOLDS is the stricter counterpart of the same key here.
# For 'cache_hit_ratio' lower is worse (alerts fire when the ratio drops
# below the value); for every other evaluated key higher is worse.
WARNING_THRESHOLDS = dict(
    connection_usage_percent=80,
    disk_usage_percent=80,            # NOTE(review): not read by any check in this file
    cache_hit_ratio=90,
    replication_lag_bytes=10 * 1024 * 1024,  # 10 MiB
    long_running_queries_minutes=5,
    table_bloat_percent=20,
    index_bloat_percent=30,           # NOTE(review): not read by any check in this file
    deadlocks_per_hour=5,
    failed_logins_per_hour=10,        # NOTE(review): not read by any check in this file
)

CRITICAL_THRESHOLDS = dict(
    connection_usage_percent=95,
    disk_usage_percent=95,
    cache_hit_ratio=80,
    replication_lag_bytes=100 * 1024 * 1024,  # 100 MiB
    long_running_queries_minutes=15,
    table_bloat_percent=40,
    index_bloat_percent=50,
    deadlocks_per_hour=20,
    failed_logins_per_hour=50,
)


class DatabaseHealthMonitor:
    """Enterprise-grade PostgreSQL health monitoring"""
    
    def __init__(self, database_url: str):
        """Initialize database connection"""
        self.database_url = database_url
        self.conn = None
        self.health_report = {
            'timestamp': datetime.utcnow().isoformat(),
            'status': 'OK',
            'checks': {},
            'warnings': [],
            'critical': [],
            'metrics': {}
        }
    
    def connect(self) -> bool:
        """Establish database connection"""
        try:
            self.conn = psycopg2.connect(self.database_url)
            self.conn.autocommit = True
            return True
        except Exception as e:
            self.health_report['status'] = 'CRITICAL'
            self.health_report['critical'].append(f"Database connection failed: {str(e)}")
            return False
    
    def check_basic_connectivity(self) -> Dict[str, Any]:
        """Check basic database connectivity"""
        try:
            cursor = self.conn.cursor()
            cursor.execute("SELECT version(), current_database(), current_user, pg_is_in_recovery()")
            result = cursor.fetchone()
            
            return {
                'status': 'OK',
                'version': result[0],
                'database': result[1],
                'user': result[2],
                'is_replica': result[3]
            }
        except Exception as e:
            return {'status': 'FAILED', 'error': str(e)}
    
    def check_connection_usage(self) -> Dict[str, Any]:
        """Monitor connection pool usage"""
        try:
            cursor = self.conn.cursor()
            cursor.execute("""
                SELECT 
                    current_setting('max_connections')::int as max_connections,
                    COUNT(*) as active_connections,
                    COUNT(*) FILTER (WHERE state = 'active') as active_queries,
                    COUNT(*) FILTER (WHERE state = 'idle') as idle_connections,
                    COUNT(*) FILTER (WHERE state = 'idle in transaction') as idle_in_transaction,
                    COUNT(*) FILTER (WHERE wait_event_type IS NOT NULL) as waiting_connections
                FROM pg_stat_activity
                WHERE pid != pg_backend_pid()
            """)
            result = cursor.fetchone()
            
            max_conn = result[0]
            active_conn = result[1]
            usage_percent = (active_conn / max_conn) * 100
            
            status = 'OK'
            if usage_percent >= CRITICAL_THRESHOLDS['connection_usage_percent']:
                status = 'CRITICAL'
                self.health_report['critical'].append(
                    f"Connection usage critical: {usage_percent:.1f}% ({active_conn}/{max_conn})"
                )
            elif usage_percent >= WARNING_THRESHOLDS['connection_usage_percent']:
                status = 'WARNING'
                self.health_report['warnings'].append(
                    f"Connection usage high: {usage_percent:.1f}% ({active_conn}/{max_conn})"
                )
            
            return {
                'status': status,
                'max_connections': max_conn,
                'active_connections': active_conn,
                'active_queries': result[2],
                'idle_connections': result[3],
                'idle_in_transaction': result[4],
                'waiting_connections': result[5],
                'usage_percent': usage_percent
            }
        except Exception as e:
            return {'status': 'FAILED', 'error': str(e)}
    
    def check_disk_usage(self) -> Dict[str, Any]:
        """Monitor disk space usage"""
        try:
            cursor = self.conn.cursor()
            cursor.execute("""
                SELECT 
                    pg_database_size(current_database()) as db_size,
                    pg_size_pretty(pg_database_size(current_database())) as db_size_pretty,
                    SUM(pg_total_relation_size(oid)) as tables_size,
                    SUM(pg_indexes_size(oid)) as indexes_size
                FROM pg_class
                WHERE relkind IN ('r', 'm')
            """)
            result = cursor.fetchone()
            
            # Get table sizes
            cursor.execute("""
                SELECT 
                    schemaname || '.' || tablename as table_name,
                    pg_size_pretty(pg_total_relation_size((schemaname||'.'||tablename)::regclass)) as total_size,
                    pg_total_relation_size((schemaname||'.'||tablename)::regclass) as size_bytes
                FROM pg_tables
                WHERE schemaname = 'public'
                ORDER BY pg_total_relation_size((schemaname||'.'||tablename)::regclass) DESC
                LIMIT 10
            """)
            top_tables = cursor.fetchall()
            
            return {
                'status': 'OK',
                'database_size': result[0],
                'database_size_pretty': result[1],
                'tables_size': result[2],
                'indexes_size': result[3],
                'top_tables': [
                    {'name': t[0], 'size': t[1], 'bytes': t[2]} 
                    for t in top_tables
                ]
            }
        except Exception as e:
            return {'status': 'FAILED', 'error': str(e)}
    
    def check_cache_performance(self) -> Dict[str, Any]:
        """Monitor cache hit ratios"""
        try:
            cursor = self.conn.cursor()
            
            # Overall cache hit ratio
            cursor.execute("""
                SELECT 
                    sum(heap_blks_read) as heap_read,
                    sum(heap_blks_hit) as heap_hit,
                    CASE 
                        WHEN sum(heap_blks_hit) + sum(heap_blks_read) = 0 THEN 0
                        ELSE (sum(heap_blks_hit) * 100.0 / (sum(heap_blks_hit) + sum(heap_blks_read)))
                    END as cache_hit_ratio
                FROM pg_statio_user_tables
            """)
            cache_result = cursor.fetchone()
            cache_hit_ratio = cache_result[2] or 0
            
            # Index cache hit ratio
            cursor.execute("""
                SELECT 
                    sum(idx_blks_read) as idx_read,
                    sum(idx_blks_hit) as idx_hit,
                    CASE 
                        WHEN sum(idx_blks_hit) + sum(idx_blks_read) = 0 THEN 0
                        ELSE (sum(idx_blks_hit) * 100.0 / (sum(idx_blks_hit) + sum(idx_blks_read)))
                    END as idx_cache_hit_ratio
                FROM pg_statio_user_indexes
            """)
            idx_result = cursor.fetchone()
            idx_cache_hit_ratio = idx_result[2] or 0
            
            status = 'OK'
            if cache_hit_ratio < CRITICAL_THRESHOLDS['cache_hit_ratio']:
                status = 'CRITICAL'
                self.health_report['critical'].append(
                    f"Cache hit ratio critical: {cache_hit_ratio:.1f}%"
                )
            elif cache_hit_ratio < WARNING_THRESHOLDS['cache_hit_ratio']:
                status = 'WARNING'
                self.health_report['warnings'].append(
                    f"Cache hit ratio low: {cache_hit_ratio:.1f}%"
                )
            
            return {
                'status': status,
                'heap_blocks_read': cache_result[0],
                'heap_blocks_hit': cache_result[1],
                'cache_hit_ratio': cache_hit_ratio,
                'index_blocks_read': idx_result[0],
                'index_blocks_hit': idx_result[1],
                'index_cache_hit_ratio': idx_cache_hit_ratio
            }
        except Exception as e:
            return {'status': 'FAILED', 'error': str(e)}
    
    def check_long_running_queries(self) -> Dict[str, Any]:
        """Check for long-running queries"""
        try:
            cursor = self.conn.cursor()
            cursor.execute("""
                SELECT 
                    pid,
                    usename,
                    application_name,
                    state,
                    query_start,
                    state_change,
                    EXTRACT(EPOCH FROM (now() - query_start)) / 60 as runtime_minutes,
                    LEFT(query, 100) as query_snippet
                FROM pg_stat_activity
                WHERE state != 'idle'
                    AND query NOT LIKE '%pg_stat_activity%'
                    AND query_start < now() - interval '%s minutes'
                ORDER BY query_start
                LIMIT 10
            """, (WARNING_THRESHOLDS['long_running_queries_minutes'],))
            
            long_queries = cursor.fetchall()
            
            status = 'OK'
            critical_queries = []
            warning_queries = []
            
            for query in long_queries:
                runtime = query[6]
                if runtime > CRITICAL_THRESHOLDS['long_running_queries_minutes']:
                    critical_queries.append({
                        'pid': query[0],
                        'user': query[1],
                        'runtime_minutes': runtime,
                        'query': query[7]
                    })
                else:
                    warning_queries.append({
                        'pid': query[0],
                        'user': query[1],
                        'runtime_minutes': runtime,
                        'query': query[7]
                    })
            
            if critical_queries:
                status = 'CRITICAL'
                self.health_report['critical'].append(
                    f"{len(critical_queries)} queries running over {CRITICAL_THRESHOLDS['long_running_queries_minutes']} minutes"
                )
            elif warning_queries:
                status = 'WARNING'
                self.health_report['warnings'].append(
                    f"{len(warning_queries)} queries running over {WARNING_THRESHOLDS['long_running_queries_minutes']} minutes"
                )
            
            return {
                'status': status,
                'critical_queries': critical_queries,
                'warning_queries': warning_queries,
                'total_long_queries': len(long_queries)
            }
        except Exception as e:
            return {'status': 'FAILED', 'error': str(e)}
    
    def check_table_bloat(self) -> Dict[str, Any]:
        """Check for table and index bloat"""
        try:
            cursor = self.conn.cursor()
            
            # Check table bloat
            cursor.execute("""
                WITH constants AS (
                    SELECT current_setting('block_size')::numeric AS bs, 23 AS hdr, 4 AS ma
                ),
                bloat_info AS (
                    SELECT 
                        schemaname,
                        tablename,
                        cc.relpages,
                        bs,
                        CEIL((cc.reltuples * ((datahdr + ma - 
                            CASE WHEN datahdr % ma = 0 THEN ma ELSE datahdr % ma END) + nullhdr2 + 4)) / (bs - 20::float)) AS otta
                    FROM (
                        SELECT 
                            schemaname,
                            tablename,
                            (datawidth + (hdr + ma - (CASE WHEN hdr % ma = 0 THEN ma ELSE hdr % ma END)))::numeric AS datahdr,
                            (maxfracsum * (nullhdr + ma - (CASE WHEN nullhdr % ma = 0 THEN ma ELSE nullhdr % ma END))) AS nullhdr2
                        FROM (
                            SELECT 
                                schemaname,
                                tablename,
                                hdr,
                                ma,
                                bs,
                                SUM((1 - null_frac) * avg_width) AS datawidth,
                                MAX(null_frac) AS maxfracsum,
                                hdr + (
                                    SELECT 1 + COUNT(*) / 8
                                    FROM pg_stats s2
                                    WHERE null_frac != 0 AND s2.schemaname = s.schemaname AND s2.tablename = s.tablename
                                ) AS nullhdr
                            FROM pg_stats s, constants
                            WHERE schemaname = 'public'
                            GROUP BY 1, 2, 3, 4, 5
                        ) AS foo
                    ) AS rs
                    JOIN pg_class cc ON cc.relname = rs.tablename
                    JOIN pg_namespace nn ON cc.relnamespace = nn.oid AND nn.nspname = rs.schemaname
                    WHERE cc.relpages > 0
                ),
                table_bloat AS (
                    SELECT 
                        schemaname,
                        tablename,
                        relpages::bigint * bs AS real_size,
                        otta::bigint * bs AS expected_size,
                        CASE WHEN relpages > 0 
                            THEN ((relpages - otta) * 100.0 / relpages)
                            ELSE 0 
                        END AS bloat_pct
                    FROM bloat_info
                    WHERE relpages > otta
                )
                SELECT 
                    tablename,
                    pg_size_pretty(real_size) as actual_size,
                    pg_size_pretty(expected_size) as expected_size,
                    ROUND(bloat_pct, 1) as bloat_percent
                FROM table_bloat
                WHERE bloat_pct > %s
                ORDER BY bloat_pct DESC
                LIMIT 10
            """, (WARNING_THRESHOLDS['table_bloat_percent'],))
            
            bloated_tables = cursor.fetchall()
            
            status = 'OK'
            if bloated_tables:
                max_bloat = max(t[3] for t in bloated_tables)
                if max_bloat > CRITICAL_THRESHOLDS['table_bloat_percent']:
                    status = 'CRITICAL'
                    self.health_report['critical'].append(
                        f"Critical table bloat detected: {bloated_tables[0][0]} at {max_bloat}%"
                    )
                else:
                    status = 'WARNING'
                    self.health_report['warnings'].append(
                        f"Table bloat detected: {len(bloated_tables)} tables over {WARNING_THRESHOLDS['table_bloat_percent']}%"
                    )
            
            return {
                'status': status,
                'bloated_tables': [
                    {
                        'table': t[0],
                        'actual_size': t[1],
                        'expected_size': t[2],
                        'bloat_percent': t[3]
                    } for t in bloated_tables
                ]
            }
        except Exception as e:
            return {'status': 'FAILED', 'error': str(e)}
    
    def check_replication_status(self) -> Dict[str, Any]:
        """Check replication lag and status"""
        try:
            cursor = self.conn.cursor()
            
            # Check if this is a primary or replica
            cursor.execute("SELECT pg_is_in_recovery()")
            is_replica = cursor.fetchone()[0]
            
            if is_replica:
                # Check replication lag on replica
                cursor.execute("""
                    SELECT 
                        EXTRACT(EPOCH FROM (now() - pg_last_xact_replay_timestamp())) as replication_lag_seconds,
                        pg_last_wal_receive_lsn() - pg_last_wal_replay_lsn() as replication_lag_bytes
                """)
                result = cursor.fetchone()
                
                lag_seconds = result[0] or 0
                lag_bytes = result[1] or 0
                
                status = 'OK'
                if lag_bytes > CRITICAL_THRESHOLDS['replication_lag_bytes']:
                    status = 'CRITICAL'
                    self.health_report['critical'].append(
                        f"Replication lag critical: {lag_bytes} bytes"
                    )
                elif lag_bytes > WARNING_THRESHOLDS['replication_lag_bytes']:
                    status = 'WARNING'
                    self.health_report['warnings'].append(
                        f"Replication lag high: {lag_bytes} bytes"
                    )
                
                return {
                    'status': status,
                    'is_replica': True,
                    'replication_lag_seconds': lag_seconds,
                    'replication_lag_bytes': lag_bytes
                }
            else:
                # Check replication slots on primary
                cursor.execute("""
                    SELECT 
                        slot_name,
                        slot_type,
                        active,
                        pg_wal_lsn_diff(pg_current_wal_lsn(), restart_lsn) as lag_bytes
                    FROM pg_replication_slots
                """)
                slots = cursor.fetchall()
                
                return {
                    'status': 'OK',
                    'is_replica': False,
                    'replication_slots': [
                        {
                            'name': s[0],
                            'type': s[1],
                            'active': s[2],
                            'lag_bytes': s[3]
                        } for s in slots
                    ]
                }
        except Exception as e:
            return {'status': 'FAILED', 'error': str(e)}
    
    def check_security_events(self) -> Dict[str, Any]:
        """Check for security-related events"""
        try:
            cursor = self.conn.cursor()
            
            # Check for failed login attempts
            cursor.execute("""
                SELECT 
                    COUNT(*) as failed_logins
                FROM pg_stat_database
                WHERE datname = current_database()
                    AND xact_rollback > 0
            """)
            failed_logins = cursor.fetchone()[0]
            
            # Check for deadlocks
            cursor.execute("""
                SELECT 
                    deadlocks
                FROM pg_stat_database
                WHERE datname = current_database()
            """)
            deadlocks = cursor.fetchone()[0]
            
            status = 'OK'
            if deadlocks > CRITICAL_THRESHOLDS['deadlocks_per_hour']:
                status = 'CRITICAL'
                self.health_report['critical'].append(
                    f"High deadlock rate: {deadlocks} deadlocks detected"
                )
            elif deadlocks > WARNING_THRESHOLDS['deadlocks_per_hour']:
                status = 'WARNING'
                self.health_report['warnings'].append(
                    f"Deadlocks detected: {deadlocks}"
                )
            
            return {
                'status': status,
                'failed_logins': failed_logins,
                'deadlocks': deadlocks
            }
        except Exception as e:
            return {'status': 'FAILED', 'error': str(e)}
    
    def run_health_check(self) -> Dict[str, Any]:
        """Run all health checks"""
        if not self.connect():
            return self.health_report
        
        # Run all checks
        checks = {
            'connectivity': self.check_basic_connectivity(),
            'connections': self.check_connection_usage(),
            'disk_usage': self.check_disk_usage(),
            'cache_performance': self.check_cache_performance(),
            'long_queries': self.check_long_running_queries(),
            'table_bloat': self.check_table_bloat(),
            'replication': self.check_replication_status(),
            'security': self.check_security_events()
        }
        
        # Update health report
        self.health_report['checks'] = checks
        
        # Determine overall status
        if self.health_report['critical']:
            self.health_report['status'] = 'CRITICAL'
        elif self.health_report['warnings']:
            self.health_report['status'] = 'WARNING'
        
        # Add summary metrics
        self.health_report['metrics'] = {
            'total_checks': len(checks),
            'passed_checks': sum(1 for c in checks.values() if c.get('status') == 'OK'),
            'warning_checks': sum(1 for c in checks.values() if c.get('status') == 'WARNING'),
            'critical_checks': sum(1 for c in checks.values() if c.get('status') == 'CRITICAL'),
            'failed_checks': sum(1 for c in checks.values() if c.get('status') == 'FAILED')
        }
        
        if self.conn:
            self.conn.close()
        
        return self.health_report
    
    def format_report(self) -> str:
        """Format health report for display"""
        report = []
        report.append("=" * 70)
        report.append("PostgreSQL Health Check Report - RateRight")
        report.append("=" * 70)
        report.append(f"Timestamp: {self.health_report['timestamp']}")
        report.append(f"Overall Status: {self.health_report['status']}")
        report.append("")
        
        # Summary
        metrics = self.health_report['metrics']
        report.append("Summary:")
        report.append(f"  Total Checks: {metrics['total_checks']}")
        report.append(f"  ✅ Passed: {metrics['passed_checks']}")
        report.append(f"  ⚠️  Warnings: {metrics['warning_checks']}")
        report.append(f"  🔴 Critical: {metrics['critical_checks']}")
        report.append(f"  ❌ Failed: {metrics['failed_checks']}")
        report.append("")
        
        # Critical Issues
        if self.health_report['critical']:
            report.append("🔴 CRITICAL ISSUES:")
            for issue in self.health_report['critical']:
                report.append(f"  - {issue}")
            report.append("")
        
        # Warnings
        if self.health_report['warnings']:
            report.append("⚠️  WARNINGS:")
            for warning in self.health_report['warnings']:
                report.append(f"  - {warning}")
            report.append("")
        
        # Detailed Check Results
        report.append("Detailed Check Results:")
        report.append("-" * 50)
        
        for check_name, check_result in self.health_report['checks'].items():
            status_icon = {
                'OK': '✅',
                'WARNING': '⚠️ ',
                'CRITICAL': '🔴',
                'FAILED': '❌'
            }.get(check_result.get('status', 'UNKNOWN'), '❓')
            
            report.append(f"\n{status_icon} {check_name.upper()}: {check_result.get('status', 'UNKNOWN')}")
            
            # Add relevant details for each check
            if check_name == 'connections' and check_result.get('status') != 'FAILED':
                report.append(f"    Active Connections: {check_result['active_connections']}/{check_result['max_connections']}")
                report.append(f"    Usage: {check_result['usage_percent']:.1f}%")
            
            elif check_name == 'cache_performance' and check_result.get('status') != 'FAILED':
                report.append(f"    Cache Hit Ratio: {check_result['cache_hit_ratio']:.1f}%")
                report.append(f"    Index Cache Hit Ratio: {check_result['index_cache_hit_ratio']:.1f}%")
            
            elif check_name == 'disk_usage' and check_result.get('status') != 'FAILED':
                report.append(f"    Database Size: {check_result['database_size_pretty']}")
                if check_result.get('top_tables'):
                    report.append(f"    Largest Table: {check_result['top_tables'][0]['name']} ({check_result['top_tables'][0]['size']})")
        
        report.append("")
        report.append("=" * 70)
        
        return "\n".join(report)


def main():
    """Command-line entry point: run all checks, print and save the report.

    Exit codes follow the Nagios/Icinga plugin convention so the script can
    be wired into standard monitoring:
    0 = OK, 1 = WARNING, 2 = CRITICAL, 3 = UNKNOWN (misconfiguration; no
    checks were run).
    """
    if not DATABASE_URL:
        print("❌ ERROR: DATABASE_URL environment variable is required", file=sys.stderr)
        print("Set it with: export DATABASE_URL='postgresql://user:pass@host:port/dbname'", file=sys.stderr)
        # 3 (UNKNOWN) rather than 1, so a missing configuration is not
        # mistaken for a database that merely has warnings.
        sys.exit(3)

    print("Starting PostgreSQL health check...")
    monitor = DatabaseHealthMonitor(DATABASE_URL)
    health_report = monitor.run_health_check()

    print(monitor.format_report())

    # default=str serializes the Decimal / datetime values psycopg2 returns.
    report_path = 'database_health_report.json'
    try:
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(health_report, f, indent=2, default=str)
    except OSError as e:
        # An unwritable working directory must not turn into an unhandled
        # traceback whose exit code would hide the real health status.
        print(f"\n⚠️  Could not save JSON report to {report_path}: {e}", file=sys.stderr)
    else:
        print(f"\n📊 Full report saved to {report_path}")

    exit_codes = {'CRITICAL': 2, 'WARNING': 1}
    sys.exit(exit_codes.get(health_report['status'], 0))


if __name__ == "__main__":
    main()
