#!/usr/bin/python3

# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one
# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at
# https://securityonion.net/license; you may not use this file except in compliance with the
# Elastic License 2.0.

"""
Script for waiting for cloud-init to complete on a Security Onion VM.
Monitors VM state to ensure proper cloud-init initialization and shutdown.

**Usage:**
    so-wait-cloud-init -n <domain_name>

**Options:**
    -n, --name          Domain name of the VM to monitor

**Exit Codes:**
- 0: Success (cloud-init completed and the VM shut down)
- 1: General error
- 2: VM never started
- 3: VM stopped too quickly
- 4: VM failed to shut down, or did not stay shut down

**Description:**
This script monitors a VM's state to ensure proper cloud-init initialization and completion:
1. Waits for the VM to start running
2. Verifies the VM remains running (not an immediate crash)
3. Waits for the VM to shut down (indicating cloud-init completion)
4. Verifies the VM remains shut down

The script is typically used in the libvirt.images state after creating a new VM
to ensure cloud-init completes its initialization before proceeding with further
configuration.

**Logging:**
- Logs are written to /opt/so/log/hypervisor/so-wait-cloud-init.log
- Both file and console logging are enabled
- Log entries include:
  - Timestamps
  - State changes
  - Error conditions
  - Verification steps
"""

import argparse
import logging
import subprocess
import sys
import time
from so_logging_utils import setup_logging

# Module-level logger used by every function in this script; it is configured
# through the project's setup_logging helper with the log file and record layout below.
logger = setup_logging(
    log_file_path='/opt/so/log/hypervisor/so-wait-cloud-init.log',
    logger_name='so-wait-cloud-init',
    format_str='%(asctime)s - %(levelname)s - %(message)s',
    log_level=logging.INFO,
)

def check_vm_running(domain_name):
    """
    Report whether the given domain appears in virsh's list of running domains.

    Runs ``virsh list --state-running --name`` and looks for an output line
    that equals ``domain_name`` exactly.

    Args:
        domain_name (str): Name of the domain to look for

    Returns:
        bool: True if the domain is listed as running. False if it is not
        listed, or if the virsh command exits with a non-zero status (the
        failure is logged, so callers cannot tell it apart from "not running").
    """
    virsh_cmd = ['virsh', 'list', '--state-running', '--name']
    try:
        completed = subprocess.run(virsh_cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to check VM state: {e}")
        return False

    # One domain name per output line; membership test is an exact match
    running_domains = completed.stdout.splitlines()
    return domain_name in running_domains

def wait_for_vm_start(domain_name, timeout=300):
    """
    Wait for VM to start running.

    Polls the VM state once per second until it is reported as running or
    the timeout expires.

    Args:
        domain_name (str): Name of the domain to monitor
        timeout (int): Maximum time to wait in seconds

    Returns:
        bool: True if VM started, False if timeout occurred
    """
    logger.info(f"Waiting for VM {domain_name} to start...")
    # Use the monotonic clock so that wall-clock adjustments (NTP steps,
    # manual time changes) cannot shorten or extend the timeout.
    deadline = time.monotonic() + timeout

    while time.monotonic() < deadline:
        if check_vm_running(domain_name):
            logger.info("VM is running")
            return True
        time.sleep(1)

    logger.error(f"Timeout waiting for VM {domain_name} to start")
    return False

def verify_vm_running(domain_name):
    """
    Confirm the VM is still running a few seconds after it started.

    Sleeps for 5 seconds and then checks the VM state once.

    Args:
        domain_name (str): Name of the domain to verify

    Returns:
        bool: True if the VM is still reported as running after the pause,
        False otherwise
    """
    logger.info("Verifying VM remains running...")
    time.sleep(5)  # Give an immediately-crashing VM time to disappear

    still_running = check_vm_running(domain_name)
    if still_running:
        logger.info("VM verified running")
        return True

    logger.error("VM stopped too quickly after starting")
    return False

def wait_for_vm_shutdown(domain_name, timeout=600):
    """
    Wait for VM to shutdown.

    Polls the VM state every 5 seconds until it is no longer reported as
    running or the timeout expires. A progress message is logged on every
    12th poll (roughly once per minute).

    Args:
        domain_name (str): Name of the domain to monitor
        timeout (int): Maximum time to wait in seconds

    Returns:
        bool: True if VM shutdown, False if timeout occurred
    """
    logger.info("Waiting for cloud-init to complete and VM to shutdown...")
    # Use the monotonic clock so that wall-clock adjustments (NTP steps,
    # manual time changes) cannot shorten or extend the timeout, or skew
    # the elapsed time reported in the progress message.
    start_time = time.monotonic()
    check_count = 0

    while time.monotonic() - start_time < timeout:
        if not check_vm_running(domain_name):
            logger.info("VM has shutdown")
            return True

        # Log status every minute (after 12 checks at 5 second intervals)
        check_count += 1
        if check_count % 12 == 0:
            elapsed = int(time.monotonic() - start_time)
            logger.info(f"Still waiting for cloud-init... ({elapsed} seconds elapsed)")

        time.sleep(5)

    logger.error(f"Timeout waiting for VM {domain_name} to shutdown")
    return False

def verify_vm_shutdown(domain_name):
    """
    Confirm the VM is still not running a few seconds after it shut down.

    Sleeps for 5 seconds and then checks the VM state once.

    Args:
        domain_name (str): Name of the domain to verify

    Returns:
        bool: True if the VM is not reported as running after the pause,
        False if it is running again
    """
    logger.info("Verifying VM remains shutdown...")
    time.sleep(5)  # Let the VM state settle before the final check

    running_again = check_vm_running(domain_name)
    if not running_again:
        logger.info("VM verified shutdown")
        return True

    logger.error("VM is still running after shutdown check")
    return False

def main():
    """
    Parse the command line and run the cloud-init wait sequence.

    The four checks run in order and the process exits with the code of the
    first one that fails:
      - 2: VM never started
      - 3: VM stopped too quickly
      - 4: VM failed to shut down, or did not stay shut down
    Exits 0 when every check passes, and 1 on an unexpected exception or
    invalid command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Wait for cloud-init to complete on a Security Onion VM"
    )
    parser.add_argument("-n", "--name", required=True,
                      help="Domain name of the VM to monitor")
    try:
        args = parser.parse_args()
    except SystemExit as exit_request:
        # argparse exits with status 2 on usage errors, which would be
        # indistinguishable from "VM never started". Report usage errors with
        # the general error code instead; --help (status 0) passes through.
        if exit_request.code in (0, None):
            raise
        sys.exit(1)

    try:
        # Wait for VM to start
        if not wait_for_vm_start(args.name):
            sys.exit(2)  # VM never started

        # Verify VM remains running
        if not verify_vm_running(args.name):
            sys.exit(3)  # VM stopped too quickly

        # Wait for VM to shutdown
        if not wait_for_vm_shutdown(args.name):
            sys.exit(4)  # VM failed to shutdown

        # Verify VM remains shutdown
        if not verify_vm_shutdown(args.name):
            sys.exit(4)  # VM failed to stay shutdown

        logger.info("Cloud-init completed successfully")
        sys.exit(0)

    except Exception as e:
        # sys.exit() raises SystemExit, which is not an Exception subclass,
        # so the exit codes above are not intercepted here.
        logger.error(f"Unexpected error: {e}")
        sys.exit(1)

# Run the wait sequence only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
