from datetime import datetime, date, timedelta
from ..extensions import db
from .base import BaseModel


class Contract(BaseModel):
    """Independent contractor agreements (Fair Work Act compliant)"""
    __tablename__ = 'contracts'

    # Contract Parties
    job_id = db.Column(db.Integer, db.ForeignKey('jobs.id'), nullable=False)
    contractor_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    worker_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)

    # Name snapshots (preserved even after user deletion for legal/historical records)
    contractor_name = db.Column(db.String(200), nullable=True)  # "FirstName LastName" at time of contract
    worker_name = db.Column(db.String(200), nullable=True)  # "FirstName LastName" at time of contract

    # Contact info snapshots (for legal/compliance/tax records)
    contractor_email = db.Column(db.String(120), nullable=True)
    contractor_phone = db.Column(db.String(20), nullable=True)
    contractor_abn = db.Column(db.String(11), nullable=True)
    contractor_business_name = db.Column(db.String(200), nullable=True)

    worker_email = db.Column(db.String(120), nullable=True)
    worker_phone = db.Column(db.String(20), nullable=True)
    worker_abn = db.Column(db.String(11), nullable=True)
    worker_business_name = db.Column(db.String(200), nullable=True)

    # Contract Terms
    agreed_rate = db.Column(db.Numeric(10, 2), nullable=False)
    rate_type = db.Column(db.String(20), default='total', nullable=False)  # total, hourly, daily
    start_date = db.Column(db.Date, nullable=False)
    end_date = db.Column(db.Date, nullable=False)
    scope_of_work = db.Column(db.Text, nullable=False)

    # Australian Legal Requirements
    independent_contractor_status = db.Column(db.Boolean, default=True, nullable=False)
    superannuation_required = db.Column(db.Boolean, default=False, nullable=False)
    workers_comp_covered = db.Column(db.Boolean, default=False, nullable=False)

    # Contract Status
    # status: 'pending_agreement', 'contractor_signed', 'worker_signed', 'signed',
    #         'active', 'pending_review', 'disputed', 'completed', 'cancelled'
    status = db.Column(db.String(20), default='pending_agreement', nullable=False)
    completion_status = db.Column(db.String(50), default="not_started", nullable=False)
    payment_status = db.Column(db.String(50), default="pending", nullable=False)
    contractor_approval_date = db.Column(db.DateTime)
    worker_completion_date = db.Column(db.DateTime)
    review_hours_status = db.Column(db.String(50), default='pending', nullable=False)

    # Signatures
    contractor_signed = db.Column(db.Boolean, default=False, nullable=False)
    worker_signed = db.Column(db.Boolean, default=False, nullable=False)
    contractor_signed_date = db.Column(db.DateTime)
    worker_signed_date = db.Column(db.DateTime)

    # Dispute Resolution
    worker_agreed_resolve = db.Column(db.Boolean, default=False, nullable=False)
    contractor_agreed_resolve = db.Column(db.Boolean, default=False, nullable=False)

    # Contract negotiation
    contractor_reviewed = db.Column(db.Boolean, default=False, nullable=False)
    worker_reviewed = db.Column(db.Boolean, default=False, nullable=False)
    last_modified_by = db.Column(db.Integer, db.ForeignKey('users.id'))

    # Rating completion tracking
    contractor_rated = db.Column(db.Boolean, default=False, nullable=False)
    worker_rated = db.Column(db.Boolean, default=False, nullable=False)
    mutual_rating_completed_date = db.Column(db.DateTime)

    # Payment Terms
    payment_terms = db.Column(db.String(50), default='completion', nullable=False)
    # Terms: 'completion', 'milestone', 'weekly', 'fortnightly'
    payment_schedule = db.Column(db.Text)  # JSON for milestone payments

    # Relationships
    job = db.relationship('Job', backref='contracts')
    contractor = db.relationship('User', foreign_keys=[contractor_id], backref='contractor_contracts')
    worker = db.relationship('User', foreign_keys=[worker_id], backref='worker_contracts')
    payments = db.relationship('Payment', backref='contract', lazy='dynamic', cascade='all, delete-orphan')
    reviews = db.relationship('Review', backref='contract', lazy='dynamic', cascade='all, delete-orphan')

    # ------------------------------------------------------------------
    # Snapshot-with-fallback accessors
    # ------------------------------------------------------------------

    @staticmethod
    def _full_name(user):
        """Format *user* as "FirstName LastName"."""
        return f"{user.first_name} {user.last_name}"

    @staticmethod
    def _snapshot_or_live(snapshot, user, attr):
        """Return *snapshot* if set, else ``getattr(user, attr)`` if *user* exists, else None.

        A falsy snapshot (None or empty string) is treated as "not captured",
        so the live user record is consulted instead.
        """
        if snapshot:
            return snapshot
        if user:
            return getattr(user, attr)
        return None

    def get_contractor_name(self):
        """Get contractor name - uses snapshot if available, falls back to user object"""
        if self.contractor_name:
            return self.contractor_name
        if self.contractor:
            return self._full_name(self.contractor)
        return "Unknown Contractor"

    def get_worker_name(self):
        """Get worker name - uses snapshot if available, falls back to user object"""
        if self.worker_name:
            return self.worker_name
        if self.worker:
            return self._full_name(self.worker)
        return "Unknown Worker"

    def get_contractor_email(self):
        """Get contractor email - uses snapshot if available, falls back to user object"""
        return self._snapshot_or_live(self.contractor_email, self.contractor, 'email')

    def get_worker_email(self):
        """Get worker email - uses snapshot if available, falls back to user object"""
        return self._snapshot_or_live(self.worker_email, self.worker, 'email')

    def get_contractor_phone(self):
        """Get contractor phone - uses snapshot if available, falls back to user object"""
        return self._snapshot_or_live(self.contractor_phone, self.contractor, 'phone_number')

    def get_worker_phone(self):
        """Get worker phone - uses snapshot if available, falls back to user object"""
        return self._snapshot_or_live(self.worker_phone, self.worker, 'phone_number')

    def get_contractor_abn(self):
        """Get contractor ABN - uses snapshot if available, falls back to user object"""
        return self._snapshot_or_live(self.contractor_abn, self.contractor, 'abn_number')

    def get_worker_abn(self):
        """Get worker ABN - uses snapshot if available, falls back to user object"""
        return self._snapshot_or_live(self.worker_abn, self.worker, 'abn_number')

    def get_contractor_business_name(self):
        """Get contractor business name - uses snapshot if available, falls back to user object"""
        return self._snapshot_or_live(self.contractor_business_name, self.contractor, 'business_name')

    def get_worker_business_name(self):
        """Get worker business name - uses snapshot if available, falls back to user object"""
        return self._snapshot_or_live(self.worker_business_name, self.worker, 'business_name')

    # ------------------------------------------------------------------
    # Signing workflow
    # ------------------------------------------------------------------

    def is_fully_signed(self):
        """Return True when both contractor and worker have signed."""
        return self.contractor_signed and self.worker_signed

    def get_next_step(self):
        """Return a short code describing what needs to happen next.

        During 'pending_agreement' the contractor reviews first, then the
        worker, then the contract is 'ready_for_signing'. Statuses without
        a specific mapping (e.g. 'signed', 'completed') are returned as-is.
        """
        if self.status == 'pending_agreement':
            if not self.contractor_reviewed:
                return 'contractor_review'
            elif not self.worker_reviewed:
                return 'worker_review'
            else:
                return 'ready_for_signing'
        elif self.status == 'contractor_signed':
            return 'waiting_worker_signature'
        elif self.status == 'worker_signed':
            return 'waiting_contractor_signature'
        elif self.status == 'active':
            return 'work_in_progress'
        return self.status

    def can_sign(self, user_id):
        """Return True if *user_id* is a party who has not yet signed and the
        contract is still in a signable status."""
        if self.status not in ['pending_agreement', 'contractor_signed', 'worker_signed']:
            return False

        if user_id == self.contractor_id and not self.contractor_signed:
            return True
        elif user_id == self.worker_id and not self.worker_signed:
            return True

        return False

    def sign_contract(self, user_id):
        """Record a signature from *user_id* and advance ``status``.

        Returns False (and changes nothing) when ``can_sign`` rejects the
        user; otherwise records the signature and timestamp and returns True.
        Once both parties have signed the status becomes 'signed' and
        ``payment_status`` is set to 'pending' (payment has not been made yet).
        Does not commit the session.
        """
        if not self.can_sign(user_id):
            return False

        # Record who signed
        if user_id == self.contractor_id:
            self.contractor_signed = True
            self.contractor_signed_date = datetime.utcnow()
        elif user_id == self.worker_id:
            self.worker_signed = True
            self.worker_signed_date = datetime.utcnow()

        if self.contractor_signed and self.worker_signed:
            # Both signed, but payment not made yet -> 'signed'
            self.status = 'signed'
            self.payment_status = 'pending'
        elif self.contractor_signed and not self.worker_signed:
            self.status = 'contractor_signed'
        elif self.worker_signed and not self.contractor_signed:
            self.status = 'worker_signed'

        return True

    # ------------------------------------------------------------------
    # Valuation / classification
    # ------------------------------------------------------------------

    def calculate_total_value(self):
        """Return ``agreed_rate`` as a float, plus 10% GST if the worker is GST registered.

        If the worker user record is missing (e.g. deleted), GST registration
        cannot be determined and the base amount is returned.

        NOTE(review): Payment.calculate_amounts multiplies agreed_rate by the
        job's hours, whereas this uses agreed_rate alone — confirm this is
        intended for rate_type values other than 'total'.
        """
        base_amount = float(self.agreed_rate)

        # Add GST if worker is GST registered
        if self.worker and self.worker.gst_registered:
            gst_amount = base_amount * 0.10
            return base_amount + gst_amount

        return base_amount

    def get_contract_type(self):
        """Determine contract type for legal compliance"""
        # Simplified logic - in production would use more complex rules
        if self.agreed_rate > 20000 or (self.end_date - self.start_date).days > 30:
            return 'independent_contractor'
        return 'casual_contractor'

    # ------------------------------------------------------------------
    # Ratings
    # ------------------------------------------------------------------

    def is_mutual_rating_complete(self):
        """Return True when both parties have submitted their ratings."""
        return self.contractor_rated and self.worker_rated

    def get_rating_status_for_user(self, user_id):
        """Return 'completed'/'pending' for a party, or 'not_authorized' for anyone else."""
        if user_id == self.contractor_id:
            return 'completed' if self.contractor_rated else 'pending'
        elif user_id == self.worker_id:
            return 'completed' if self.worker_rated else 'pending'
        return 'not_authorized'

    def get_pending_ratings_summary(self):
        """Return the list of parties ('contractor', 'worker') that have not rated yet."""
        pending = []
        if not self.contractor_rated:
            pending.append('contractor')
        if not self.worker_rated:
            pending.append('worker')
        return pending

    # ------------------------------------------------------------------
    # Completion
    # ------------------------------------------------------------------

    def mark_as_completed(self):
        """Mark the contract completed, bump both parties' job counts, and commit.

        Sets ``status`` and ``completion_status`` to 'completed', stamps
        ``worker_completion_date`` with UTC now, increments ``jobs_completed``
        on each party that still exists (treating None as 0), then commits
        the session.

        NOTE(review): not idempotent — calling this twice increments
        ``jobs_completed`` twice. Callers should ensure it runs once.
        """
        self.status = 'completed'
        self.completion_status = 'completed'
        self.worker_completion_date = datetime.utcnow()

        # Safely increment job counts (None-tolerant, skips deleted users)
        if self.worker:
            self.worker.jobs_completed = (self.worker.jobs_completed or 0) + 1
            db.session.add(self.worker)

        if self.contractor:
            self.contractor.jobs_completed = (self.contractor.jobs_completed or 0) + 1
            db.session.add(self.contractor)

        # Make sure the contract itself is part of the session before committing
        db.session.add(self)
        db.session.commit()

    def __repr__(self):
        return f'<Contract {self.id} - {self.status}>'


class Payment(BaseModel):
    """Escrow payment system for construction contracts"""
    __tablename__ = 'payments'

    # Payment Reference
    contract_id = db.Column(db.Integer, db.ForeignKey('contracts.id'), nullable=False)
    payment_reference = db.Column(db.String(50), unique=True, nullable=False)

    # Stripe Integration
    stripe_payment_intent_id = db.Column(db.String(255), unique=True)
    stripe_client_secret = db.Column(db.String(255))
    stripe_status = db.Column(db.String(50))
    stripe_transfer_id = db.Column(db.String(255), unique=True)  # Transfer to worker
    stripe_transfer_status = db.Column(db.String(50))  # Transfer status

    # Payment Amounts (all in AUD)
    gross_amount = db.Column(db.Numeric(10, 2), nullable=False)
    platform_fee = db.Column(db.Numeric(10, 2), nullable=False)
    gst_amount = db.Column(db.Numeric(10, 2), default=0, nullable=False)
    net_to_worker = db.Column(db.Numeric(10, 2), nullable=False)

    # Withholding Tax (for non-GST registered contractors)
    withholding_tax_rate = db.Column(db.Numeric(5, 4), default=0, nullable=False)
    withholding_tax_amount = db.Column(db.Numeric(10, 2), default=0, nullable=False)

    # Payment Status
    status = db.Column(db.String(20), default='pending', nullable=False)
    # Status: 'pending', 'held_escrow', 'released', 'transferred', 'refunded', 'disputed'

    # Payment Timeline
    date_initiated = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
    date_held_escrow = db.Column(db.DateTime)
    date_released = db.Column(db.DateTime)
    release_conditions_met = db.Column(db.Boolean, default=False, nullable=False)

    # Dispute Protection
    dispute_period_days = db.Column(db.Integer, default=7, nullable=False)
    dispute_deadline = db.Column(db.DateTime)

    # Relationships
    invoices = db.relationship('Invoice', backref='payment', lazy='dynamic', cascade='all, delete-orphan')

    # Job.duration labels -> billable hours. Both the current underscore
    # format and the legacy human-readable format are accepted.
    _DURATION_HOURS = {
        # New format (with underscores)
        'half_day': 4,
        'full_day': 8,
        'long_day': 10,
        '2_days': 16,
        '3_5_days': 32,
        '1_week_plus': 40,
        # Old format (legacy - for backwards compatibility)
        '4 hours': 4,
        '8 hours': 8,
        '10 hours': 10,
        '2 days': 16,
        '3-5 days': 32,
        '1 week+': 40,
    }
    # Hours assumed when the duration is neither numeric nor a known label.
    _DEFAULT_DURATION_HOURS = 8

    @classmethod
    def _hours_for_duration(cls, duration):
        """Translate a job duration into billable hours.

        Values accepted by ``int()`` (ints or numeric strings) are used
        directly; anything else (including None) is looked up in
        ``_DURATION_HOURS``, falling back to ``_DEFAULT_DURATION_HOURS``.
        """
        try:
            return int(duration)
        except (ValueError, TypeError):
            return cls._DURATION_HOURS.get(duration, cls._DEFAULT_DURATION_HOURS)

    def calculate_amounts(self):
        """Compute and assign every monetary column from the linked contract.

        Billable base = ``contract.agreed_rate`` x job hours. From that base:

        * ``gst_amount``    - 10% of base if the worker is GST registered, else 0
        * ``platform_fee``  - 7% of base, or 3% when worker and contractor share
                              the same (non-empty) referral code
        * ``gross_amount``  - base + GST (what the contractor pays)
        * ``net_to_worker`` - gross - platform fee - Stripe fee, where the
                              Stripe fee (2.9% + $0.30) is charged on gross
        * withholding tax   - always reset to 0 here

        Amounts are computed with ``Decimal`` and rounded half-up to whole
        cents so they match the ``Numeric(10, 2)`` columns and reconcile
        exactly: gross == net_to_worker + platform_fee + Stripe fee.
        """
        from decimal import Decimal, ROUND_HALF_UP

        def to_cents(value):
            # Round half-up to whole cents, matching Numeric(10, 2) storage.
            return value.quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)

        contract = self.contract
        hours = self._hours_for_duration(contract.job.duration)

        # str() first so a float rate does not carry binary-fraction noise into Decimal.
        base_amount = to_cents(Decimal(str(contract.agreed_rate)) * hours)

        # GST (10% in Australia) applies only when the worker is GST registered.
        worker = contract.worker
        contractor = contract.contractor
        is_gst_registered = worker.gst_registered if worker else False
        gst_amount = to_cents(base_amount * Decimal('0.10')) if is_gst_registered else Decimal('0.00')
        self.gst_amount = gst_amount

        # Platform fee (on base only): 7% default, 3% when both parties share a referral code.
        platform_fee_rate = Decimal('0.07')
        if (worker and contractor and
                worker.referral_code and contractor.referral_code and
                worker.referral_code == contractor.referral_code):
            platform_fee_rate = Decimal('0.03')
        platform_fee = to_cents(base_amount * platform_fee_rate)
        self.platform_fee = platform_fee

        # Withholding tax is not applied by this calculation.
        self.withholding_tax_amount = Decimal('0.00')
        self.withholding_tax_rate = Decimal('0')

        # The contractor pays base + GST; the worker receives the GST and must remit it to the ATO.
        gross_amount = base_amount + gst_amount
        self.gross_amount = gross_amount

        # Stripe processing fee (2.9% + $0.30) is charged on the GROSS amount
        # (base + GST) and, like the platform fee, deducted from the worker's share.
        stripe_fee = to_cents(gross_amount * Decimal('0.029') + Decimal('0.30'))
        self.net_to_worker = gross_amount - platform_fee - stripe_fee

    def can_be_released(self):
        """Return True if the payment is in escrow, the dispute window (if any)
        has passed, and the release conditions have been met."""
        if self.status != 'held_escrow':
            return False

        # Check if dispute period has passed
        if self.dispute_deadline and datetime.utcnow() < self.dispute_deadline:
            return False

        return self.release_conditions_met

    def auto_generate_invoice(self):
        """Create (but do not commit) a GST tax invoice for this payment.

        Returns the existing invoice if one is already linked to this payment.
        Returns None when no invoice is needed or possible: the worker is not
        GST registered, or the worker/contractor user record is missing.
        Otherwise builds an ``Invoice`` with status 'sent', adds it to the
        session, and returns it. GST compliance problems are logged as a
        warning but do not block creation.
        """
        import logging

        contract = self.contract
        worker = contract.worker
        contractor = contract.contractor

        # Check if invoice already exists for this payment
        existing_invoice = Invoice.query.filter_by(payment_id=self.id).first()
        if existing_invoice:
            return existing_invoice

        # Only GST-registered workers issue tax invoices; without both user
        # records the supplier/buyer details cannot be filled in.
        if worker is None or contractor is None or not worker.gst_registered:
            return None

        # Amounts come from calculate_amounts(), which already applied GST.
        amount_ex_gst = float(self.gross_amount) - float(self.gst_amount)

        invoice = Invoice(
            payment_id=self.id,
            contractor_id=contractor.id,
            description=f"Construction services for {contract.job.title}",
            invoice_date=date.today(),
            due_date=date.today() + timedelta(days=30),
            amount_ex_gst=amount_ex_gst,
            gst_amount=float(self.gst_amount),
            total_amount=float(self.gross_amount),
            gst_rate=0.10,  # worker is GST registered at this point
            # '00000000000' is a placeholder when no ABN is on file.
            supplier_abn=worker.abn_number or '00000000000',
            supplier_name=worker.business_name or f"{worker.first_name} {worker.last_name}",
            buyer_abn=contractor.abn_number or '00000000000',
            buyer_name=contractor.business_name or f"{contractor.first_name} {contractor.last_name}",
            status='sent'  # Auto-generated invoices are sent immediately
        )

        # Generate unique invoice number
        invoice.generate_invoice_number()

        # Validate GST compliance; log but still create the invoice.
        is_compliant, compliance_errors = invoice.validate_gst_compliance()
        if not is_compliant:
            logging.getLogger(__name__).warning(
                "Invoice GST compliance warning for payment %s: %s",
                self.payment_reference, compliance_errors,
            )

        db.session.add(invoice)
        return invoice

    def __repr__(self):
        return f'<Payment {self.payment_reference} - {self.status}>'


class Invoice(BaseModel):
    """GST-compliant invoices for Australian tax law"""
    __tablename__ = 'invoices'

    # Invoice Reference
    payment_id = db.Column(db.Integer, db.ForeignKey('payments.id'), nullable=False)
    contractor_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    invoice_number = db.Column(db.String(50), unique=True, nullable=False)

    # Invoice Details
    invoice_date = db.Column(db.Date, default=date.today, nullable=False)
    due_date = db.Column(db.Date, nullable=False)
    description = db.Column(db.Text, nullable=False)

    # GST Compliance (Australian Tax Office requirements)
    amount_ex_gst = db.Column(db.Numeric(10, 2), nullable=False)
    gst_amount = db.Column(db.Numeric(10, 2), nullable=False)
    total_amount = db.Column(db.Numeric(10, 2), nullable=False)
    gst_rate = db.Column(db.Numeric(5, 4), default=0.10, nullable=False)  # 10% GST

    # Business Details (for GST compliance)
    supplier_abn = db.Column(db.String(11), nullable=False)
    supplier_name = db.Column(db.String(200), nullable=False)
    buyer_abn = db.Column(db.String(11), nullable=False)
    buyer_name = db.Column(db.String(200), nullable=False)

    # Invoice Status
    status = db.Column(db.String(20), default='draft', nullable=False)
    # Status: 'draft', 'sent', 'paid', 'overdue', 'cancelled'

    # Australian Business Register weighting factors for the ABN check-digit algorithm.
    _ABN_WEIGHTS = (10, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19)

    def generate_invoice_number(self):
        """Assign a new invoice number of the form ``RR<year>-<8 upper-case hex chars>``.

        NOTE(review): the suffix is 32 random bits, so the chance of hitting the
        unique constraint grows with volume (about 50% after ~77,000 invoices in
        one year). Consider a longer suffix if volumes approach that range.
        """
        import uuid
        year = date.today().year
        short_uuid = str(uuid.uuid4())[:8].upper()
        self.invoice_number = f"RR{year}-{short_uuid}"

    @classmethod
    def _is_valid_abn(cls, abn):
        """Return True if *abn* is an 11-digit ABN that passes the ABR checksum.

        Algorithm: subtract 1 from the first digit, multiply each digit by the
        matching weight in ``_ABN_WEIGHTS``, sum the products, and require the
        sum to be divisible by 89. The all-zero placeholder fails this check.
        """
        if not abn or len(abn) != 11 or not all(ch in '0123456789' for ch in abn):
            return False
        digits = [int(ch) for ch in abn]
        digits[0] -= 1
        return sum(weight * digit for weight, digit in zip(cls._ABN_WEIGHTS, digits)) % 89 == 0

    def validate_gst_compliance(self):
        """Check the invoice against basic Australian GST invoice requirements.

        Returns a ``(is_compliant, errors)`` tuple where *errors* is a list of
        human-readable messages. Checks that both ABNs are well-formed and pass
        the ABR checksum, that ``gst_amount`` equals ``amount_ex_gst * gst_rate``
        and that ``total_amount`` equals ``amount_ex_gst + gst_amount``, each
        within a 1-cent tolerance.
        """
        errors = []

        # ABN validation (format + checksum, so placeholders are rejected)
        if not self._is_valid_abn(self.supplier_abn):
            errors.append("Valid supplier ABN required")

        if not self._is_valid_abn(self.buyer_abn):
            errors.append("Valid buyer ABN required")

        # Amount validation
        expected_gst = float(self.amount_ex_gst) * float(self.gst_rate)
        if abs(float(self.gst_amount) - expected_gst) > 0.01:
            errors.append("GST amount calculation error")

        expected_total = float(self.amount_ex_gst) + float(self.gst_amount)
        if abs(float(self.total_amount) - expected_total) > 0.01:
            errors.append("Total amount calculation error")

        return len(errors) == 0, errors

    def __repr__(self):
        return f'<Invoice {self.invoice_number} - ${self.total_amount}>'
