import uuid
from datetime import datetime

from mongoengine import (
    BooleanField,
    DateTimeField,
    DictField,
    Document,
    IntField,
    StringField,
)


class Case(Document):
    """Case model for legal case management.

    Stored in the ``cases`` collection. Cases are soft-deleted (flagged via
    ``is_deleted`` / ``deleted_at`` / ``deleted_by``) rather than removed.
    """
    meta = {'collection': 'cases'}

    # Primary fields
    case_id = StringField(required=True, unique=True, default=lambda: str(uuid.uuid4()))
    name = StringField(required=True)
    case_number = StringField(unique=True)  # Generated by generate_case_number(): CASE-YYYY-NNNN
    client = StringField(required=True)  # Client name (for backward compatibility)
    client_id = StringField()  # Reference to Client model
    description = StringField()
    case_type = StringField(choices=['contract', 'litigation', 'advisory', 'compliance', 'other'])
    priority = StringField(choices=['low', 'normal', 'high', 'urgent'], default='normal')

    # Status and dates
    status = StringField(required=True, choices=['active', 'inactive', 'closed', 'pending', 'archived'], default='active')
    start_date = DateTimeField()
    end_date = DateTimeField()

    # Ownership and tracking
    user_id = StringField(required=True)  # Owner of the case
    created_by = StringField(required=True)
    updated_by = StringField()

    # Counts for documents and AI exports
    documents_count = IntField(default=0)
    ai_exports_count = IntField(default=0)

    # Additional metadata
    metadata = DictField()  # For any additional case-specific data

    # Soft delete
    # Boolean flag: a StringField here rejected both the False default and the
    # True set by soft_delete() during validation, so save() would fail.
    # NOTE(review): verify no legacy string values exist in the collection
    # (BooleanField coerces via bool(), so the string 'False' would read as True).
    is_deleted = BooleanField(default=False)
    deleted_at = DateTimeField()
    deleted_by = StringField()

    # Timestamps (naive UTC, from datetime.utcnow)
    created_at = DateTimeField(default=datetime.utcnow)
    updated_at = DateTimeField(default=datetime.utcnow)

    def __repr__(self):
        return f'<Case {self.case_number}: {self.name}>'

    @staticmethod
    def generate_case_number():
        """Generate the next case number for the current (UTC) year.

        The format is ``CASE-YYYY-NNNN``. The sequence is zero-padded to at
        least four digits and simply grows wider past 9999.

        Returns:
            str: The highest numeric sequence already stored for this year
            plus one, or ``CASE-YYYY-0001`` when none exist. Stored numbers
            whose suffix is not an integer are ignored.

        Note:
            This reads and then returns a value without any locking. Two
            concurrent callers can compute the same number. The ``unique``
            constraint on ``case_number`` then makes the second save fail.
        """
        year = datetime.utcnow().year
        prefix = f'CASE-{year}-'

        # Compare sequences numerically. A descending string sort ranks
        # 'CASE-YYYY-9999' above 'CASE-YYYY-10000' and would keep re-issuing
        # 10000 once the sequence passes 9999.
        existing_numbers = Case.objects(case_number__startswith=prefix).scalar('case_number')

        last_seq = 0
        for number in existing_numbers:
            if not number:
                continue
            try:
                seq = int(number[len(prefix):])
            except ValueError:
                # Malformed suffix: skip it rather than fail generation.
                continue
            last_seq = max(last_seq, seq)

        return f'{prefix}{last_seq + 1:04d}'

    def soft_delete(self, deleted_by_user_id):
        """Soft delete the case and persist the change.

        Flags the case as deleted, records who deleted it and when, sets
        ``status`` to ``'closed'``, refreshes ``updated_at`` and saves.

        Args:
            deleted_by_user_id: Identifier of the user performing the delete.
        """
        now = datetime.utcnow()
        self.is_deleted = True
        self.deleted_at = now
        self.deleted_by = deleted_by_user_id
        self.status = 'closed'
        self.updated_at = now
        self.save()

    def restore(self):
        """Restore a soft-deleted case and persist the change.

        Clears the deletion markers, refreshes ``updated_at`` and saves.
        ``status`` is always reset to ``'active'``; the status the case had
        before soft_delete() is not recorded, so it cannot be restored.
        """
        self.is_deleted = False
        self.deleted_at = None
        self.deleted_by = None
        self.status = 'active'
        self.updated_at = datetime.utcnow()
        self.save()
