#!/usr/bin/env python3

import argparse
import subprocess
import os
import sys
import time

def get_args():
    """Parse the wrapper's command-line options from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with ``ssh``, ``domain``, ``aliases``,
        ``altnames``, ``test_cert`` and ``challenge_path`` attributes.
    """
    cli = argparse.ArgumentParser(
        description='a2utils certbot wrapper to generate certificates with remote validation',
    )

    # One (flags, keyword options) entry per argument, registered in order.
    option_specs = (
        (('--ssh',), dict(nargs='+', help='SSH arguments, e.g.  --ssh root@example.com')),
        (('-d', '--domain'), dict(nargs='+', help='Domain name to generate certificate')),
        (('--aliases',), dict(default=False, action='store_true')),
        (('--altnames',), dict(metavar='CERT.pem', help='Path to remote cert to extract all altnames')),
        (('--test-cert',), dict(default=False, action='store_true')),
        (('--challenge-path', '-p'), dict(help='Optional path to token directory (/something/.well-known/acme-challenge/)')),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()



def printenv():
    """Print each environment variable whose name starts with CERTBOT or A2UTILS.

    Every match is written to stdout as ``NAME value`` on its own line.
    """
    wanted_prefixes = ('CERTBOT', 'A2UTILS')
    for name in os.environ:
        if name.startswith(wanted_prefixes):
            print(name, os.environ[name])

def process_aliases(ssh, domain):
    """Fetch the alias names for the given domain(s) from the remote host.

    Runs ``ssh <ssh...> a2vhost -d <domain...> --aliases`` and returns the
    whitespace-separated names the command prints on stdout.

    Args:
        ssh: list of arguments placed right after ``ssh``.
        domain: list of domain names passed to ``a2vhost -d``.

    Returns:
        A non-empty list of names.

    Exits the process with status 1 if ssh cannot be started, the remote
    command returns a non-zero status, or it prints no names.
    """
    cmd = ['ssh', *ssh, 'a2vhost', '-d', *domain, '--aliases']
    try:
        out = subprocess.run(cmd, capture_output=True)
    except OSError as exc:
        # Typically the ssh client is not installed / not on PATH.
        print("Failed to get aliases")
        print(exc, file=sys.stderr)
        sys.exit(1)

    if out.returncode:
        print("Failed to get aliases")
        # stderr was captured above; show it instead of silently dropping it.
        sys.stderr.write(out.stderr.decode(errors='replace'))
        sys.exit(1)

    # split() without a separator copes with newlines and repeated spaces and
    # never produces empty strings (which would end up as `-d ''` for certbot).
    aliases = out.stdout.decode().split()
    if not aliases:
        print("Failed to get aliases: remote command returned no names")
        sys.exit(1)
    return aliases

def process_altnames(ssh, path):
    """Fetch the alternative names of a certificate stored on the remote host.

    Runs ``ssh <ssh...> a2certbot --altnames <path>`` and returns the
    whitespace-separated names the command prints on stdout.

    Args:
        ssh: list of arguments placed right after ``ssh``.
        path: path of the certificate on the remote host.

    Returns:
        A non-empty list of names.

    Exits the process with status 1 if ssh cannot be started, the remote
    command returns a non-zero status, or it prints no names.
    """
    cmd = ['ssh', *ssh, 'a2certbot', '--altnames', path]
    try:
        out = subprocess.run(cmd, capture_output=True)
    except OSError as exc:
        # Typically the ssh client is not installed / not on PATH.
        print("Failed to get altnames")
        print(exc, file=sys.stderr)
        sys.exit(1)

    if out.returncode:
        print("Failed to get altnames")
        # stderr was captured above; show it instead of silently dropping it.
        sys.stderr.write(out.stderr.decode(errors='replace'))
        sys.exit(1)

    # split() without a separator copes with newlines and repeated spaces and
    # never produces empty strings (which would end up as `-d ''` for certbot).
    altnames = out.stdout.decode().split()
    if not altnames:
        print("Failed to get altnames: remote command returned no names")
        sys.exit(1)
    return altnames



def certbot_wrapper(ssh, domain, challenge_dir, test_cert):
    """Run ``certbot certonly --manual`` with this script as both hooks.

    This file is passed as ``--manual-auth-hook`` and
    ``--manual-cleanup-hook``. The SSH arguments and the challenge directory
    reach those hook invocations through the A2UTILS_SSH and
    A2UTILS_CHALLENGE_DIR variables set in certbot's environment.

    Args:
        ssh: list of ssh arguments; stored space-joined in A2UTILS_SSH.
        domain: list of domain names, each passed to certbot as ``-d NAME``.
        challenge_dir: token directory for the hooks; None or '' means unset.
        test_cert: when true, ``--test-cert`` is added to the command.

    Exits the process with status 1 if certbot cannot be started, or with
    certbot's own status when that is non-zero.
    """
    # Build a private copy so this process's own environment is not modified.
    child_env = dict(
        os.environ,
        A2UTILS_SSH=' '.join(ssh),
        A2UTILS_CHALLENGE_DIR=challenge_dir or '',
    )

    # Absolute path so the hook can still be located if certbot runs it from
    # a different working directory.
    hook = os.path.abspath(__file__)

    cmd = ['certbot', 'certonly', '--manual']
    if test_cert:
        cmd.append('--test-cert')
    cmd += ['--manual-auth-hook', hook, '--manual-cleanup-hook', hook]
    for name in domain:
        cmd += ['-d', name]

    try:
        result = subprocess.run(cmd, env=child_env)
    except OSError as exc:
        # Typically certbot is not installed / not on PATH.
        print("Failed to run certbot:", exc)
        sys.exit(1)

    # Surface certbot's failure through our own exit status.
    if result.returncode:
        sys.exit(result.returncode)

def get_role():
    """Decide from the environment which job this invocation performs.

    Returns:
        'cleanup' if CERTBOT_AUTH_OUTPUT is present in the environment,
        otherwise 'place' if CERTBOT_DOMAIN is present, otherwise 'wrapper'.
    """
    # Checked in priority order; presence alone counts, even with an empty value.
    markers = (
        ('CERTBOT_AUTH_OUTPUT', 'cleanup'),
        ('CERTBOT_DOMAIN', 'place'),
    )
    for variable, role in markers:
        if variable in os.environ:
            return role
    return 'wrapper'

def place():
    """Auth-hook role: publish the challenge token on the remote host.

    Reads CERTBOT_DOMAIN, CERTBOT_TOKEN, CERTBOT_VALIDATION and A2UTILS_SSH
    (required) plus A2UTILS_CHALLENGE_DIR (optional) from the environment and
    runs ``ssh <A2UTILS_SSH...> a2certbot -d <domain> --place <token>
    <validation> [--challenge-path <dir>]``.

    Exits the process with status 1 if a required variable is missing or
    empty, and with the ssh command's status when that is non-zero.
    """
    required = ('CERTBOT_DOMAIN', 'CERTBOT_TOKEN', 'CERTBOT_VALIDATION', 'A2UTILS_SSH')
    missing = [name for name in required if not os.getenv(name)]
    if missing:
        # Without this check an unset variable surfaces as an AttributeError
        # or TypeError traceback instead of a readable message.
        print("Missing environment variable(s): " + ', '.join(missing), file=sys.stderr)
        sys.exit(1)

    domain = os.environ['CERTBOT_DOMAIN']
    token = os.environ['CERTBOT_TOKEN']
    validation = os.environ['CERTBOT_VALIDATION']
    # Split on single spaces: the value is expected to be a space-joined
    # argument list.
    sshcmd = os.environ['A2UTILS_SSH'].split(' ')
    challenge_dir = os.getenv('A2UTILS_CHALLENGE_DIR', '')

    cmd = ['ssh', *sshcmd, 'a2certbot', '-d', domain, '--place', token, validation]
    if challenge_dir:
        cmd.extend(['--challenge-path', challenge_dir])

    result = subprocess.run(cmd)
    # NOTE(review): an unfinished comment here used to read "we need any
    # output to have"; presumably it concerns this hook's stdout - confirm.
    if result.returncode:
        sys.exit(result.returncode)

def cleanup():
    """Cleanup-hook role: remove the challenge token from the remote host.

    Reads CERTBOT_DOMAIN, CERTBOT_TOKEN and A2UTILS_SSH (required) plus
    A2UTILS_CHALLENGE_DIR (optional) from the environment and runs
    ``ssh <A2UTILS_SSH...> a2certbot -d <domain> --cleanup <token>
    [--challenge-path <dir>]``.

    Exits the process with status 1 if a required variable is missing or
    empty, and with the ssh command's status when that is non-zero.
    """
    required = ('CERTBOT_DOMAIN', 'CERTBOT_TOKEN', 'A2UTILS_SSH')
    missing = [name for name in required if not os.getenv(name)]
    if missing:
        # Without this check an unset variable surfaces as an AttributeError
        # or TypeError traceback instead of a readable message.
        print("Missing environment variable(s): " + ', '.join(missing), file=sys.stderr)
        sys.exit(1)

    domain = os.environ['CERTBOT_DOMAIN']
    token = os.environ['CERTBOT_TOKEN']
    # Split on single spaces: the value is expected to be a space-joined
    # argument list.
    sshcmd = os.environ['A2UTILS_SSH'].split(' ')
    challenge_dir = os.getenv('A2UTILS_CHALLENGE_DIR', '')

    cmd = ['ssh', *sshcmd, 'a2certbot', '-d', domain, '--cleanup', token]
    if challenge_dir:
        cmd.extend(['--challenge-path', challenge_dir])

    result = subprocess.run(cmd)
    if result.returncode:
        sys.exit(result.returncode)

def main():
    """Entry point: run as the certbot wrapper or as one of its hooks.

    The role comes from ``get_role()``:

    * 'wrapper' - validate the command-line options, optionally replace the
      domain list with remote aliases (``--aliases``) or the altnames of a
      remote certificate (``--altnames``), then call ``certbot_wrapper``.
    * 'place' - call ``place()``.
    * 'cleanup' - call ``cleanup()``.

    Invalid option combinations print a message and exit with status 1.
    """
    args = get_args()
    role = get_role()

    if role == 'wrapper':
        if not args.domain and not args.altnames:
            print("Need domain(s) (-d)")
            sys.exit(1)

        # --aliases looks up the aliases *of* the given domains, so it cannot
        # work without -d (process_aliases would fail unpacking None).
        if args.aliases and not args.domain:
            print("--aliases needs domain(s) (-d)")
            sys.exit(1)

        if not args.ssh:
            # The default ssh target is derived from the first domain; with
            # only --altnames there is no domain to derive it from.
            if not args.domain:
                print("Need --ssh when no domain (-d) is given")
                sys.exit(1)
            args.ssh = ['root@' + args.domain[0]]
            print("--ssh not set, using --ssh " + ' '.join(args.ssh))

        if args.aliases:
            args.domain = process_aliases(args.ssh, args.domain)
        elif args.altnames:
            args.domain = process_altnames(args.ssh, args.altnames)

        certbot_wrapper(
            ssh=args.ssh,
            domain=args.domain,
            challenge_dir=args.challenge_path,
            test_cert=args.test_cert,
        )
    elif role == 'place':
        place()
    elif role == 'cleanup':
        cleanup()
    

# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()