import logging
import traceback
from datetime import datetime, timezone

from flask import jsonify, render_template, request

# Module-level logger shared by every handler and helper in this file.
logger = logging.getLogger(__name__)


class APIError(Exception):
    """Base exception that carries an HTTP status code and extra payload data.

    ``to_dict`` turns the error into the dictionary used for JSON responses.
    """

    def __init__(self, message, status_code=400, payload=None):
        super().__init__(message)
        self.message, self.status_code = message, status_code
        # A falsy payload (None or an empty mapping) becomes a fresh dict.
        self.payload = payload if payload else {}

    def to_dict(self):
        """Return the serialisable form of the error.

        Payload entries are merged last, so they take precedence over the
        standard keys when names collide.
        """
        base = dict(
            error=True,
            message=self.message,
            status_code=self.status_code,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )
        return {**base, **self.payload}


class ValidationError(APIError):
    """422 error raised when form or input validation fails.

    Optionally records the offending ``field`` and a mapping of per-field
    ``errors``.
    """

    def __init__(self, message, field=None, errors=None):
        super().__init__(message, status_code=422)
        self.field = field
        self.errors = errors if errors else {}

    def to_dict(self):
        """Extend the base dictionary with the field details that are set."""
        body = super().to_dict()
        optional = (('field', self.field), ('validation_errors', self.errors))
        # Only truthy values are included in the response.
        body.update({key: value for key, value in optional if value})
        return body


class AuthenticationError(APIError):
    """401 error: the caller has not authenticated."""

    def __init__(self, message="Authentication required"):
        # The status code is fixed; only the message can be customised.
        super().__init__(message, 401)


class AuthorizationError(APIError):
    """403 error: the caller is not allowed to perform the action."""

    def __init__(self, message="Access denied"):
        # The status code is fixed; only the message can be customised.
        super().__init__(message, 403)


class ResourceNotFoundError(APIError):
    """404 error for a missing resource.

    Args:
        resource_type: Human-readable name of the resource kind; used as the
            start of the message ("<resource_type> not found").
        resource_id: Optional identifier appended to the message. ``None``
            and the empty string mean "no identifier"; any other value,
            including ``0``, is shown.
    """

    def __init__(self, resource_type="Resource", resource_id=None):
        message = f"{resource_type} not found"
        # Compare against None/"" explicitly so a legitimate ID of 0 is not
        # silently dropped from the message.
        if resource_id is not None and resource_id != "":
            message += f" (ID: {resource_id})"
        super().__init__(message, status_code=404)
        # Keep the structured details available to handlers and callers.
        self.resource_type = resource_type
        self.resource_id = resource_id


class DatabaseError(APIError):
    """500 error wrapping a failed database operation.

    Args:
        message: Message reported to the client.
        original_error: The underlying exception, if any. It is stored on the
            instance and logged at construction time.
    """

    def __init__(self, message="Database operation failed", original_error=None):
        super().__init__(message, status_code=500)
        self.original_error = original_error

        # Log the original database error for debugging. Passing the exception
        # itself as exc_info logs *its* traceback; exc_info=True would log
        # whatever exception happens to be active, or "NoneType: None" when
        # this is constructed outside an except block.
        if original_error is not None:
            logger.error(f"Database error: {original_error}", exc_info=original_error)


class RateLimitError(APIError):
    """429 error raised when a client exceeds its request quota.

    Args:
        message: Message reported to the client.
        retry_after: Optional retry hint. When given (including ``0``) it is
            added to the response payload under ``retry_after`` and kept on
            the instance as ``self.retry_after``.
    """

    def __init__(self, message="Rate limit exceeded", retry_after=None):
        super().__init__(message, status_code=429)
        self.retry_after = retry_after
        # Compare against None so that an explicit 0 ("retry immediately")
        # is still reported.
        if retry_after is not None:
            self.payload['retry_after'] = retry_after


def _wants_json():
    """Return a truthy value when the current request should get a JSON error body.

    That is the case for requests with a JSON body (``request.is_json``) and
    for any path starting with ``/api/``; everything else gets an HTML page.
    """
    return request.is_json or request.path.startswith('/api/')


def _error_payload(message, status_code, **extra):
    """Build the JSON error body shared by the handlers.

    The result always contains ``error``, ``message``, ``status_code`` and a
    UTC ISO-8601 ``timestamp``; ``extra`` adds handler-specific keys.
    """
    payload = {'error': True, 'message': message, 'status_code': status_code}
    payload.update(extra)
    payload['timestamp'] = datetime.now(timezone.utc).isoformat()
    return payload


def _error_response(payload, template, **context):
    """Return ``(body, status)``: JSON for API clients, else the rendered template."""
    status_code = payload['status_code']
    if _wants_json():
        return jsonify(payload), status_code
    return render_template(template, **context), status_code


def register_error_handlers(app):
    """Register all error handlers with the Flask app.

    Handlers are installed for the custom ``APIError`` hierarchy, for common
    HTTP status codes, for SQLAlchemy and Flask-JWT-Extended exceptions (when
    those packages can be imported) and, last, a catch-all for ``Exception``.
    Each handler answers with JSON for JSON/``/api/`` requests and with a
    rendered ``errors/*.html`` template otherwise.

    Args:
        app: The Flask application to register the handlers on.
    """

    @app.errorhandler(APIError)
    def handle_api_error(error):
        """Handle custom API errors"""
        logger.warning(f"API Error: {error.message} (Status: {error.status_code})")

        if _wants_json():
            return jsonify(error.to_dict()), error.status_code
        # For web requests, render error page
        return render_template('errors/error.html',
                               error=error.message,
                               status_code=error.status_code), error.status_code

    @app.errorhandler(ValidationError)
    def handle_validation_error(error):
        """Handle validation errors with detailed field information"""
        logger.info(f"Validation Error: {error.message} - Field: {error.field}")

        if _wants_json():
            return jsonify(error.to_dict()), error.status_code
        return render_template('errors/validation_error.html',
                               error=error), error.status_code

    @app.errorhandler(400)
    def handle_bad_request(error):
        """Handle 400 Bad Request errors"""
        logger.warning(f"Bad Request: {request.url} - {error.description}")
        payload = _error_payload('Bad request - invalid data provided', 400)
        return _error_response(payload, 'errors/400.html')

    @app.errorhandler(401)
    def handle_unauthorized(error):
        """Handle 401 Unauthorized errors"""
        logger.warning(f"Unauthorized access attempt: {request.url} - IP: {request.remote_addr}")
        payload = _error_payload('Authentication required', 401)
        return _error_response(payload, 'errors/401.html')

    @app.errorhandler(403)
    def handle_forbidden(error):
        """Handle 403 Forbidden errors"""
        logger.warning(f"Forbidden access attempt: {request.url} - IP: {request.remote_addr}")
        payload = _error_payload('Access denied - insufficient permissions', 403)
        return _error_response(payload, 'errors/403.html')

    @app.errorhandler(404)
    def handle_not_found(error):
        """Handle 404 Not Found errors"""
        logger.info(f"Page not found: {request.url}")
        payload = _error_payload('Resource not found', 404)
        return _error_response(payload, 'errors/404.html')

    @app.errorhandler(405)
    def handle_method_not_allowed(error):
        """Handle 405 Method Not Allowed errors"""
        logger.warning(f"Method not allowed: {request.method} {request.url}")

        # valid_methods may be missing or None; always report a list.
        allowed_methods = list(getattr(error, 'valid_methods', None) or [])
        payload = _error_payload(
            f'Method {request.method} not allowed for this endpoint',
            405,
            allowed_methods=allowed_methods,
        )
        return _error_response(payload, 'errors/405.html',
                               allowed_methods=allowed_methods)

    @app.errorhandler(413)
    def handle_payload_too_large(error):
        """Handle 413 Payload Too Large errors"""
        logger.warning(f"Payload too large: {request.url} - Size: {request.content_length}")

        max_size = app.config.get('MAX_CONTENT_LENGTH', 16 * 1024 * 1024)
        payload = _error_payload('File or request too large', 413, max_size=max_size)
        return _error_response(payload, 'errors/413.html', max_size=max_size)

    @app.errorhandler(429)
    def handle_rate_limit_exceeded(error):
        """Handle 429 Rate Limit Exceeded errors"""
        logger.warning(f"Rate limit exceeded: {request.url} - IP: {request.remote_addr}")

        # The error may expose retry_after as None; fall back to 60 seconds
        # instead of advertising "None".
        retry_after = getattr(error, 'retry_after', None)
        if retry_after is None:
            retry_after = 60

        payload = _error_payload('Rate limit exceeded - too many requests', 429,
                                 retry_after=retry_after)
        body, status = _error_response(payload, 'errors/429.html',
                                       retry_after=retry_after)
        # Returning the header in the response tuple sets Retry-After for the
        # HTML page too; render_template returns a plain string with no
        # ``headers`` attribute to assign to.
        return body, status, {'Retry-After': str(retry_after)}

    @app.errorhandler(500)
    def handle_internal_server_error(error):
        """Handle 500 Internal Server Error"""
        # Log the full traceback for debugging
        logger.error(f"Internal Server Error: {request.url}", exc_info=True)

        # In production, don't expose internal error details
        debug = app.config.get('DEBUG')
        if debug:
            error_message = str(error)
            traceback_info = traceback.format_exc()
        else:
            error_message = 'An internal server error occurred'
            traceback_info = None

        payload = _error_payload(error_message, 500)
        if debug and traceback_info:
            payload['traceback'] = traceback_info

        return _error_response(payload, 'errors/500.html',
                               error_message=error_message,
                               debug=debug)

    @app.errorhandler(502)
    def handle_bad_gateway(error):
        """Handle 502 Bad Gateway errors"""
        logger.error(f"Bad Gateway: {request.url}")
        payload = _error_payload('Bad gateway - upstream server error', 502)
        return _error_response(payload, 'errors/502.html')

    @app.errorhandler(503)
    def handle_service_unavailable(error):
        """Handle 503 Service Unavailable errors"""
        logger.error(f"Service Unavailable: {request.url}")
        payload = _error_payload('Service temporarily unavailable', 503)
        return _error_response(payload, 'errors/503.html')

    # Handle SQLAlchemy database errors
    try:
        from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError

        @app.errorhandler(IntegrityError)
        def handle_integrity_error(error):
            """Handle database integrity constraint violations"""
            logger.error(f"Database Integrity Error: {error}", exc_info=True)

            # Try to provide user-friendly error messages; the first marker
            # found in the driver message wins.
            detail = str(error)
            error_message = "Data integrity error"
            friendly_messages = (
                ("UNIQUE constraint failed", "This record already exists"),
                ("NOT NULL constraint failed", "Required field is missing"),
                ("FOREIGN KEY constraint failed", "Referenced record does not exist"),
            )
            for marker, friendly in friendly_messages:
                if marker in detail:
                    error_message = friendly
                    break

            payload = _error_payload(error_message, 400,
                                     type='database_integrity_error')
            return _error_response(payload, 'errors/database_error.html',
                                   error_message=error_message)

        @app.errorhandler(OperationalError)
        def handle_operational_error(error):
            """Handle database operational errors"""
            logger.error(f"Database Operational Error: {error}", exc_info=True)

            payload = _error_payload('Database connection error - please try again',
                                     503, type='database_operational_error')
            return _error_response(payload, 'errors/database_error.html',
                                   error_message=payload['message'])

        @app.errorhandler(SQLAlchemyError)
        def handle_sqlalchemy_error(error):
            """Handle general SQLAlchemy errors"""
            logger.error(f"SQLAlchemy Error: {error}", exc_info=True)

            payload = _error_payload('Database error occurred', 500,
                                     type='database_error')
            return _error_response(payload, 'errors/database_error.html',
                                   error_message=payload['message'])

    except ImportError:
        logger.warning("SQLAlchemy not available for error handling")

    # Handle JWT errors
    try:
        from flask_jwt_extended.exceptions import JWTExtendedException

        @app.errorhandler(JWTExtendedException)
        def handle_jwt_exceptions(error):
            """Handle JWT related errors (always answered with JSON)"""
            logger.warning(f"JWT Error: {error}")

            payload = _error_payload('Authentication token error', 401,
                                     type='jwt_error')
            return jsonify(payload), 401

    except ImportError:
        logger.warning("Flask-JWT-Extended not available for error handling")

    # Generic exception handler for unhandled exceptions
    @app.errorhandler(Exception)
    def handle_generic_exception(error):
        """Handle any unhandled exceptions"""
        # HTTP errors without a dedicated handler above (410, 415, ...) are
        # also routed here because they derive from Exception. Return them
        # unchanged so they keep their own status code instead of becoming a
        # 500. Duck-typed on the ``code``/``get_response`` pair that HTTP
        # exceptions expose, to avoid adding a new import.
        if (isinstance(getattr(error, 'code', None), int)
                and callable(getattr(error, 'get_response', None))):
            return error

        logger.error(f"Unhandled Exception: {error}", exc_info=True)

        # Don't expose internal error details in production
        debug = app.config.get('DEBUG')
        if debug:
            error_message = str(error)
            error_type = type(error).__name__
        else:
            error_message = 'An unexpected error occurred'
            error_type = 'internal_error'

        payload = _error_payload(error_message, 500, type=error_type)
        return _error_response(payload, 'errors/500.html',
                               error_message=error_message,
                               debug=debug)


def log_slow_queries(app):
    """Log slow database queries for performance monitoring.

    Does nothing unless ``SQLALCHEMY_RECORD_QUERIES`` is enabled in the app
    config. Otherwise registers an ``after_request`` hook that logs a warning
    for every recorded query whose duration is at least
    ``SLOW_QUERY_THRESHOLD`` seconds (default 0.5).

    Args:
        app: The Flask application to attach the hook to.
    """
    if not app.config.get('SQLALCHEMY_RECORD_QUERIES'):
        return

    try:
        # Flask-SQLAlchemy 3.x exposes the recorded queries here ...
        from flask_sqlalchemy.record_queries import get_recorded_queries as get_queries
    except ImportError:
        try:
            # ... while older releases only provide get_debug_queries.
            from flask_sqlalchemy import get_debug_queries as get_queries
        except ImportError:
            # Same policy as the other optional integrations in this module:
            # warn and carry on rather than crash application start-up.
            logger.warning("Flask-SQLAlchemy not available for slow query logging")
            return

    @app.after_request
    def after_request(response):
        slow_threshold = app.config.get('SLOW_QUERY_THRESHOLD', 0.5)

        for query in get_queries():
            if query.duration >= slow_threshold:
                logger.warning(f"Slow query ({query.duration:.3f}s): {query.statement}")

        return response


def setup_request_logging(app):
    """Attach before/after request hooks that log non-static traffic.

    Requests whose path starts with ``/static/`` are not logged.
    """
    static_prefix = '/static/'

    def log_request_info():
        """Log method, path and client IP of the incoming request."""
        if request.path.startswith(static_prefix):
            return
        logger.info(f"Request: {request.method} {request.path} - IP: {request.remote_addr}")

    def log_response_info(response):
        """Log the status code of the outgoing response and pass it through."""
        is_static = request.path.startswith(static_prefix)
        if not is_static:
            logger.info(f"Response: {response.status_code} - {request.method} {request.path}")
        return response

    # Register explicitly instead of using decorator syntax.
    app.before_request(log_request_info)
    app.after_request(log_response_info)


# Utility functions for error handling
def safe_execute(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and convert any failure into a DatabaseError.

    Args:
        func: The callable to run.
        *args, **kwargs: Passed through to ``func`` unchanged.

    Returns:
        Whatever ``func`` returns.

    Raises:
        DatabaseError: If ``func`` raises any ``Exception``. The original
            exception is attached as ``original_error`` and chained as the
            cause.
    """
    try:
        return func(*args, **kwargs)
    except Exception as e:
        # functools.partial objects and callable instances have no __name__;
        # fall back to repr so error reporting itself cannot fail.
        func_name = getattr(func, '__name__', repr(func))
        logger.error(f"Error executing {func_name}: {e}", exc_info=True)
        raise DatabaseError(f"Operation failed: {func_name}", original_error=e) from e


def validate_required_fields(data, required_fields):
    """Validate that required fields are present in data.

    A field counts as missing when its key is absent, its value is ``None``,
    or its value is an empty non-numeric value (``""``, ``[]``, ``{}``, ...).
    Numbers and booleans are always accepted, so ``0`` and ``False`` are
    valid values for a required field.

    Args:
        data: Mapping of submitted values; ``None`` is treated as empty.
        required_fields: Iterable of field names that must be present.

    Raises:
        ValidationError: If any field is missing. The message lists the
            missing fields and ``errors`` maps each one to a description.
    """
    from numbers import Number

    if data is None:
        data = {}

    def _is_missing(field):
        if field not in data:
            return True
        value = data[field]
        if value is None:
            return True
        # 0 and False are legitimate submitted values, not absent ones.
        if isinstance(value, Number):
            return False
        return not value

    missing_fields = [field for field in required_fields if _is_missing(field)]

    if missing_fields:
        raise ValidationError(
            f"Missing required fields: {', '.join(missing_fields)}",
            errors={field: f"{field} is required" for field in missing_fields}
        )


def validate_email(email):
    """Validate email format.

    Args:
        email: The address to check.

    Raises:
        ValidationError: If ``email`` is not a string or does not match the
            ``local@domain.tld`` pattern in its entirety.
    """
    import re
    pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
    # fullmatch (unlike match with a trailing "$") rejects a trailing newline;
    # the isinstance guard turns None/non-string input into a validation
    # error instead of a TypeError from the regex engine.
    if not isinstance(email, str) or not re.fullmatch(pattern, email):
        raise ValidationError("Invalid email format", field="email")


def validate_abn(abn):
    """Validate Australian Business Number format.

    Only the format is checked: the value must be a string of exactly 11
    ASCII digits.

    Args:
        abn: The value to check.

    Raises:
        ValidationError: If ``abn`` is not a string of 11 ASCII digits.
    """
    import re
    # str.isdigit() also accepts non-ASCII digit characters, and len() fails
    # on non-string input; an explicit [0-9] pattern avoids both.
    if not isinstance(abn, str) or not re.fullmatch(r'[0-9]{11}', abn):
        raise ValidationError("ABN must be 11 digits", field="abn_number")
