#!/usr/bin/env python3
"""
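Make Alembic migrations idempotent for production deployment.

Rewrites upgrade() in each revision under migrations/versions so that
batch_op.add_column / batch_op.drop_column calls are guarded by a schema
check, letting the migration be re-run against a database where those
changes were already applied.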
"""

import ast
import os
import re

def fix_migration_file(filepath):
    """Fix a single migration file to be idempotent"""
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()
    
    original_content = content
    
    # Pattern to find column additions
    add_column_pattern = r'batch_op\.add_column\(sa\.Column\([\'"](\w+)[\'"]'
    
    # Pattern to find column drops
    drop_column_pattern = r'batch_op\.drop_column\([\'"](\w+)[\'"]'
    
    # Check if already has inspector code
    if 'inspector = sa.inspect(conn)' in content:
        print(f"  Already fixed: {os.path.basename(filepath)}")
        return False
    
    # Find the upgrade function
    if 'def upgrade():' in content:
        # Check if it has add_column operations
        if 'add_column' in content or 'drop_column' in content:
            # Build the new upgrade function
            new_upgrade = '''def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    # Check existing schema before making changes
    conn = op.get_bind()
    inspector = sa.inspect(conn)
    
    # Get existing tables
    existing_tables = inspector.get_table_names()
'''
            
            # Extract table operations
            if 'with op.batch_alter_table' in content:
                # Find all batch_alter_table blocks
                table_pattern = r"with op\.batch_alter_table\('(\w+)'.*?\) as batch_op:(.*?)(?=\n\n|\Z)"
                matches = re.findall(table_pattern, content, re.DOTALL)
                
                for table_name, operations in matches:
                    new_upgrade += f'''
    if '{table_name}' in existing_tables:
        columns = [col['name'] for col in inspector.get_columns('{table_name}')]
        
        with op.batch_alter_table('{table_name}', schema=None) as batch_op:'''
                    
                    # Process each operation in the batch
                    for line in operations.split('\n'):
                        line = line.strip()
                        if 'add_column' in line:
                            # Extract column name
                            col_match = re.search(r"Column\(['\"](\w+)['\"]", line)
                            if col_match:
                                col_name = col_match.group(1)
                                # Make it conditional
                                new_upgrade += f'''
            if '{col_name}' not in columns:
                {line}'''
                        elif 'drop_column' in line:
                            # Extract column name
                            col_match = re.search(r"drop_column\(['\"](\w+)['\"]", line)
                            if col_match:
                                col_name = col_match.group(1)
                                # Make it conditional
                                new_upgrade += f'''
            if '{col_name}' in columns:
                {line}'''
                        elif line and not line.startswith('#'):
                            new_upgrade += f'''
            {line}'''
            
            new_upgrade += '''

    # ### end Alembic commands ###'''
            
            # Replace the upgrade function
            upgrade_pattern = r'def upgrade\(\):.*?(?=\ndef |$)'
            content = re.sub(upgrade_pattern, new_upgrade, content, flags=re.DOTALL)
            
            # Save the file
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(content)
            
            print(f"  Fixed: {os.path.basename(filepath)}")
            return True
    
    return False

def main():
    """Make every Alembic revision under ``migrations/versions`` idempotent.

    The path is relative to the current working directory, so run this from
    the project root. Each ``*.py`` file that does not start with ``__`` is
    handed to ``fix_migration_file``. A count of rewritten files and a short
    deployment checklist are printed at the end.
    """
    banner = "=" * 60
    versions_dir = 'migrations/versions'

    print("Fixing all migration files to be idempotent...")
    print(banner)

    if not os.path.exists(versions_dir):
        print(f"Error: {versions_dir} directory not found")
        return

    candidates = (
        os.path.join(versions_dir, entry)
        for entry in os.listdir(versions_dir)
        if entry.endswith('.py') and not entry.startswith('__')
    )
    fixed_total = sum(1 for path in candidates if fix_migration_file(path))

    print("\n" + banner)
    print(f"Fixed {fixed_total} migration files")
    print("\nNext steps:")
    for step in (
        "1. Review the changes",
        "2. Commit the files",
        "3. Deploy to Fly.io: flyctl deploy -a rateright-au",
    ):
        print(step)

# Entry point: run main() only when executed as a script, not on import.
if __name__ == '__main__':
    main()
