#!/usr/bin/env python3

import a2conf
import argparse
import logging
import requests
import os
import socket
import random
import string
import sys
import subprocess
import pathlib
import ssl

from a2conf import Node, VhostNotFound

# Module-wide logger used by the helper functions in this file.
# It starts as a placeholder; main() rebinds it to logging.getLogger('diag').
log = None


class FatalError(Exception):
    """Raised to stop checking the current certificate once a blocking problem is recorded."""

class LetsEncryptCertificateConfig:
    """Domain -> webroot mapping of one certificate.

    The mapping is either read from a renewal config file or synthesised from
    a webroot and a list of domains. ``content`` maps each section header,
    exactly as written in the file (e.g. ``'[[webroot_map]]'``), to a dict of
    its ``key = value`` pairs. Pairs found before the first header are stored
    under the ``''`` key.
    """

    def __init__(self, path, webroot=None, domains=None):
        """
        :param path: config file to read; ignored when ``webroot`` is given
        :param webroot: if set, build an in-memory config instead of reading ``path``
        :param domains: domains to map to ``webroot`` (only used with ``webroot``)
        """
        self.content = dict()
        if webroot:
            self.init_webroot(webroot, domains)
        else:
            self.init_readfile(path)

    def init_webroot(self, webroot, domains):
        """Build a synthetic config that maps every domain to ``webroot``."""
        self.path = '::internal::'
        self.content = dict()
        # A missing domain list is treated as empty instead of failing on iteration.
        self.content['[[webroot_map]]'] = {domain: webroot for domain in (domains or ())}

    def init_readfile(self, path):
        """Parse ``path`` into ``self.content``.

        Blank lines and lines starting with ``#`` are skipped. A line starting
        with ``[`` opens a new section. Any other line is split on its first
        ``=`` only, so values that themselves contain ``=`` are kept intact.
        A line without ``=`` is stored as a key with an empty value.
        """
        self.path = path
        self.content = dict()
        self.content[''] = dict()
        section = ''

        with open(path) as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    # skip empty lines and comments
                    continue

                if line.startswith('['):
                    # new section
                    section = line
                    self.content[section] = dict()
                else:
                    key, _sep, value = line.partition('=')
                    self.content[section][key.strip()] = value.strip()

    @property
    def domains(self):
        """Keys of the ``[[webroot_map]]`` section.

        :raises KeyError: when the section is absent (a note is printed first)
        """
        try:
            return self.content['[[webroot_map]]'].keys()
        except KeyError:
            print("No [[webroot_map]] in {}".format(self.path))
            raise

    def get_droot(self, domain):
        """Return the webroot recorded for ``domain`` in ``[[webroot_map]]``."""
        return self.content['[[webroot_map]]'][domain]

    def dump(self):
        """Print the parsed content (debug helper)."""
        print(self.content)


class Report:
    """Accumulates notes and problems for one checked item and prints them as a block."""

    def __init__(self, name):
        self.name = name
        self._info = []
        self._problem = []
        self.prefix = '    '
        # message -> list of objects (e.g. hostnames) the message applies to
        self.objects = {}

    def info(self, msg, object=None):
        """Record a note; with ``object`` set, group the note under that message."""
        if not object:
            self._info.append(msg)
            return
        self.objects.setdefault(msg, []).append(object)

    def problem(self, msg):
        """Record a problem message."""
        self._problem.append(msg)

    def has_problems(self):
        """Return True if at least one problem was recorded."""
        return len(self._problem) > 0

    def report(self):
        """Print the header, grouped notes, plain notes and problems to stdout."""
        header = "=== {} PROBLEM ===" if self._problem else "=== {} ==="
        print(header.format(self.name))

        if self._info or self.objects:
            print("Info:")
            for note in self.objects:
                subjects = ', '.join(self.objects[note])
                print("{}({}) {}".format(self.prefix, subjects, note))

            for note in self._info:
                print(self.prefix + note)

        if self._problem:
            print("Problems:")
            for issue in self._problem:
                print(self.prefix + issue)

        print("---\n")


def detect_ip():
    """
    Return the addresses treated as local: loopback plus the address reported
    by http://ifconfig.me/.

    :return: list ``['127.0.0.1', <text returned by ifconfig.me>]``
    :raises requests.RequestException: if the HTTP request itself fails
    :raises RuntimeError: if the service answers with a status other than 200
    """
    url = 'http://ifconfig.me/'
    try:
        # A timeout keeps the tool from hanging forever when the service is unreachable.
        r = requests.get(url, timeout=10)
    except requests.RequestException as e:
        log.error('Failed to get IP from {} ({}), use --ip a.b.c.d'.format(url, e))
        raise

    if r.status_code != 200:
        msg = 'Failed to get IP from {} ({}), use --ip a.b.c.d'.format(url, r.status_code)
        log.error(msg)
        # Explicit raise instead of assert: asserts are removed under "python -O".
        raise RuntimeError(msg)

    # Strip surrounding whitespace so the value compares equal to resolved addresses.
    return ['127.0.0.1', r.text.strip()]


def resolve(name):
    """
    Look up the IPv4 addresses of a hostname.

    :param name: hostname to resolve
    :return: list of address strings; an empty list (after logging a warning)
             when the lookup fails with socket.gaierror
    """
    try:
        _canonical, _aliases, addresses = socket.gethostbyname_ex(name)
    except socket.gaierror:
        log.warning("WARNING: Cannot resolve {}".format(name))
        return []
    return addresses

def simulate_check(servername, droot, report):
    """
    Simulate an HTTP challenge: write a random file under
    ``<droot>/.well-known/acme-challenge/`` and fetch it from
    ``http://<servername>/.well-known/acme-challenge/<file>``.

    The outcome is recorded on ``report``. The test file is always removed
    again, even if the check itself fails unexpectedly.

    :param servername: hostname to request
    :param droot: directory the web server is expected to serve for servername
    :param report: Report that receives info/problem messages
    :return: True if the fetched body equals the written content, else False
    """
    alphabet = string.ascii_uppercase + string.ascii_lowercase + string.digits
    test_data = ''.join(random.choice(alphabet) for _ in range(100))
    test_basename = 'certbot_diag_' + ''.join(random.choice(alphabet) for _ in range(10))
    test_dir = os.path.join(droot, '.well-known', 'acme-challenge')
    test_file = os.path.join(test_dir, test_basename)
    test_url = 'http://' + servername + '/.well-known/acme-challenge/' + test_basename

    log.debug('create test file ' + test_file)
    try:
        os.makedirs(test_dir, exist_ok=True)
        with open(test_file, "w") as f:
            f.write(test_data)
    except OSError as e:
        # An unwritable webroot is a finding to report, not a reason to crash.
        report.problem("Cannot create test file {}: {}".format(test_file, e))
        return False

    success = False
    try:
        log.debug('test URL ' + test_url)
        try:
            # Timeout so that an unresponsive server becomes a reported problem.
            r = requests.get(test_url, allow_redirects=True, timeout=30)
        except requests.RequestException as e:
            report.problem("URL {} got exception: {}".format(test_url, e))
        else:
            if r.status_code != 200:
                report.problem(
                    'URL {} got status code {}. Used DocumentRoot {}. Maybe Alias or RewriteRule working?'.format(
                        test_url, r.status_code, droot))
            elif r.text == test_data:
                report.info("Simulated check match root: {}".format(droot), object=servername)
                success = True
            else:
                report.problem("Simulated check fails root: {} url: {}".format(droot, test_url))
    finally:
        # Never leave the probe file behind in the webroot.
        os.unlink(test_file)

    return success


def is_local_ip(hostname, local_ip_list, report):
    """
    Check that every address ``hostname`` resolves to is in ``local_ip_list``.

    Each address is recorded on ``report`` as info (local) or as a problem
    (not local). A hostname that does not resolve at all is recorded as a
    problem too, instead of passing silently.

    :param hostname: name to resolve
    :param local_ip_list: addresses considered local to this server
    :param report: Report that receives the messages
    :return: True if the name resolved and all addresses are local, else False
    """
    iplist = resolve(hostname)
    if not iplist:
        report.problem('{} does not resolve to any IP address'.format(hostname))
        return False

    all_local = True
    for ip in iplist:
        if ip in local_ip_list:
            report.info('is local {}'.format(ip), object=hostname)
        else:
            report.problem('{} ({}) not local {}'.format(hostname, ip, local_ip_list))
            all_local = False
    return all_local

def get_all_hostnames(hostname, apacheconf=None, root=None):
    """
    Return ServerName plus all ServerAlias names of the first vhost matching
    ``hostname``.

    :param hostname: name used to locate the vhost
    :param apacheconf: Apache config path (passed through to yield_vhost)
    :param root: already parsed config tree (passed through to yield_vhost)
    :return: list of names, ServerName first
    """
    vhost = next(yield_vhost(hostname, apacheconf, root=root))
    names = [next(vhost.children('ServerName')).args]
    for alias in vhost.children('ServerAlias'):
        # split() without an argument copes with tabs and repeated spaces,
        # so no empty "names" end up in the result.
        names.extend(alias.args.split())
    return names


def get_webroot(hostname, apacheconf, root=None):
    """
    Return the DocumentRoot argument of the first vhost matching ``hostname``.

    :param hostname: name used to locate the vhost
    :param apacheconf: Apache config path (passed through to yield_vhost)
    :param root: already parsed config tree (passed through to yield_vhost)
    :raises StopIteration: if the vhost has no DocumentRoot directive
    """
    matches = yield_vhost(hostname, apacheconf, root=root)
    vhost = next(matches)
    droot_node = next(vhost.children('DocumentRoot'))
    return droot_node.args


def _listens_on_port_80(vhost_args):
    """Return True if any address in a <VirtualHost> argument string ends in port 80."""
    # Compare the port exactly: a substring test would also accept *:8080
    # or an address such as 192.168.0.80:443.
    return any(addr.rstrip('>').rsplit(':', 1)[-1] == '80' for addr in vhost_args.split())


def yield_vhost(domain, apacheconf=None, root=None):
    """
    Yield every port-80 VirtualHost whose ServerName or ServerAlias equals
    ``domain`` (case-insensitive). Each matching vhost is yielded once.

    :param domain: hostname to look for
    :param apacheconf: Apache config file to read when ``root`` is not given
    :param root: already parsed config tree; read from ``apacheconf`` if None
    :raises VhostNotFound: only when no vhost matched at all
    """
    if root is None:
        root = a2conf.Node()
        root.read_file(apacheconf)

    wanted = domain.lower()
    found = False

    for vhost in root.children('<VirtualHost>'):
        if not _listens_on_port_80(vhost.args):
            continue
        try:
            servername = next(vhost.children('servername')).args
        except StopIteration:
            # vhost without ServerName cannot be matched by name
            continue

        names = [servername]
        for alias in vhost.children('serveralias'):
            names.extend(alias.args.split())

        if any(wanted == name.lower() for name in names):
            found = True
            yield vhost

    # Raising unconditionally here would make list(yield_vhost(...)) fail
    # even after successful matches, so only signal a genuine miss.
    if not found:
        raise VhostNotFound('Not found vhost {}'.format(domain))

def process_file(leconf_path, local_ip_list, args, leconf=None):
    """
    Check one LetsEncrypt certificate config against the Apache config and
    print a report.

    For every domain of the certificate this checks, in order: that the name
    resolves to a local address, that exactly one port-80 vhost serves it,
    that the vhost's DocumentRoot exists, that no Redirect directive captures
    the challenge path, and that the LetsEncrypt webroot matches the
    DocumentRoot (or --altroot). It then runs a simulated HTTP challenge.
    The first blocking problem stops the checks for this certificate.

    :param leconf_path: path of the renewal config, or None when ``leconf`` is given
    :param local_ip_list: addresses considered local to this server
    :param args: parsed command line; reads host, apacheconf, altroot and quiet
    :param leconf: ready-made LetsEncryptCertificateConfig used instead of reading leconf_path
    :return: None (the report is printed unless --quiet is set and there are no problems)
    """
    report = Report(leconf_path or 'manual')

    try:

        if not leconf:
            report.info("LetsEncrypt conf file: " + leconf_path)
            if os.path.exists(leconf_path):
                lc = LetsEncryptCertificateConfig(path=leconf_path)
            else:
                report.problem("Missing LetsEncrypt conf file " + leconf_path)
                raise FatalError
        else:
            lc = leconf

        # A config without [[webroot_map]] must not abort the whole run with KeyError.
        try:
            domains = list(lc.domains)
        except KeyError:
            domains = None

        if args.host and (domains is None or args.host not in domains):
            log.debug('Skip file {}: not found domain {}'.format(leconf_path, args.host))
            return

        if domains is None:
            report.problem('No [[webroot_map]] section in {}'.format(lc.path))
            raise FatalError

        for domain in domains:
            log.debug("check domain {} from {}".format(domain, leconf_path or 'manual'))
            le_droot = lc.get_droot(domain)

            # Local IP check
            is_local_ip(domain, local_ip_list, report)

            # yield_vhost may raise VhostNotFound once its matches are used up.
            # Keep whatever was collected, and count each vhost only once.
            vhost_list = []
            try:
                for candidate in yield_vhost(domain, args.apacheconf):
                    if not any(candidate is seen for seen in vhost_list):
                        vhost_list.append(candidate)
            except VhostNotFound:
                pass

            if not vhost_list:
                report.problem('Not found domain {} in {}'.format(domain, args.apacheconf))
                raise FatalError

            if len(vhost_list) > 1:
                report.problem('Found {} virtualhost for {} in {}'.format(len(vhost_list), domain, args.apacheconf))
                raise FatalError

            vhost = vhost_list[0]

            report.info('Vhost: {}:{}'.format(vhost.path, vhost.line), object=domain)

            #
            # DocumentRoot exists?
            #
            droot = None
            try:
                droot = next(vhost.children('DocumentRoot')).args
            except StopIteration:
                report.problem("No DocumentRoot in vhost at {}:{}".format(vhost.path, vhost.line))
                raise FatalError
            else:
                if droot is not None and os.path.isdir(droot):
                    report.info("DocumentRoot: {}".format(droot), object=domain)
                else:
                    report.problem("DocumentRoot dir not exists: {} (problem!)".format(droot))

            #
            # Redirect check
            #
            for redirect in vhost.children('Redirect'):
                tokens = redirect.args.split()
                # "Redirect [status] URL-path URL": skip the optional status
                # keyword or numeric code so the URL-path is what gets examined.
                if tokens and (tokens[0].lower() in ('permanent', 'temp', 'seeother', 'gone')
                               or tokens[0].isdigit()):
                    tokens = tokens[1:]
                if not tokens:
                    continue
                rpath = tokens[0]
                if rpath in ('/', '.well-known') or rpath.startswith('/.well-known'):
                    report.problem('Requests will be redirected: {} {}'.format(redirect, redirect.args))

            #
            # DocumentRoot matches?
            #

            if not args.altroot:
                # No altroot, simple check
                if os.path.realpath(le_droot) == os.path.realpath(droot):
                    report.info('DocumentRoot {} matches LetsEncrypt and Apache'.format(droot), object=domain)
                else:
                    report.problem(
                        'DocRoot mismatch for {}. Apache: {} LetsEncrypt: {}'.format(domain, droot, le_droot))
                # simulate anyway
                simulate_check(domain.lower(), droot, report)
            else:
                # AltRoot
                if os.path.realpath(le_droot) == os.path.realpath(args.altroot):
                    report.info('Domain name {} le root {} matches --altroot'.format(domain, le_droot))
                    simulate_check(domain.lower(), le_droot, report)
                elif os.path.realpath(le_droot) == os.path.realpath(droot):
                    report.info('Domain name {} le root {} matches DocumentRoot'.format(domain, le_droot))
                    simulate_check(domain.lower(), droot, report)
                else:
                    report.problem(
                        'DocRoot mismatch for {}. AltRoot: {} LetsEncrypt: {} Apache: {}'.format(
                            domain, args.altroot, le_droot, droot))

            log.debug("END OF ITER for {}".format(domain))

    except FatalError:
        # The blocking problem is already recorded on the report.
        pass

    #
    # Final report
    #
    if report.has_problems() or not args.quiet:
        report.report()

    return

def get_aliases(names, apacheconf):
    """
    Combine the given names with every hostname of the vhost serving names[0].

    :param names: names to include as-is; the first one selects the vhost
    :param apacheconf: apache config file name
    :return: set of all names
    """
    given = list(names)
    from_vhost = get_all_hostnames(names[0], apacheconf)
    return set(given + from_vhost)


def place(apacheconfig, challenge_dir, token, validation, domain):
    """
    Write the validation string into the token file of an HTTP challenge.

    :param apacheconfig: Apache config path; only read when ``challenge_dir`` is None
    :param challenge_dir: directory for the token file, or None to use
                          ``<DocumentRoot>/.well-known/acme-challenge`` of the
                          vhost found for ``domain``
    :param token: token file name
    :param validation: content to write into the token file
    :param domain: domain being validated
    :return: None; prints "<domain> <token>" on success
    """
    if challenge_dir is None:
        # The Apache config is only needed to discover the DocumentRoot, so
        # it is not parsed when the caller names the directory explicitly.
        try:
            vhost = Node(apacheconfig).find_vhost(domain)
            droot = vhost.first('DocumentRoot').args
            challenge_dir = os.path.join(droot, '.well-known', 'acme-challenge')
        except VhostNotFound:
            print("Not found vhost for {}".format(domain))
            return

    token_path = os.path.join(challenge_dir, token)
    # create dir if needed
    pathlib.Path(challenge_dir).mkdir(parents=True, exist_ok=True)

    with open(token_path, "w") as fh:
        fh.write(validation)
    print(domain + " " + token)

def cleanup(apacheconfig, challenge_dir, token, domain):
    """
    Remove the token file of an HTTP challenge if it exists.

    :param apacheconfig: Apache config path; only read when ``challenge_dir`` is None
    :param challenge_dir: directory holding the token file, or None to use
                          ``<DocumentRoot>/.well-known/acme-challenge`` of the
                          vhost found for ``domain``
    :param token: token file name
    :param domain: domain that was validated
    :return: None
    """
    if challenge_dir is None:
        # The Apache config is only needed to discover the DocumentRoot, so
        # it is not parsed when the caller names the directory explicitly.
        try:
            vhost = Node(apacheconfig).find_vhost(domain)
            droot = vhost.first('DocumentRoot').args
            challenge_dir = os.path.join(droot, '.well-known', 'acme-challenge')
        except VhostNotFound:
            print("Not found vhost for {}".format(domain))
            return

    token_path = os.path.join(challenge_dir, token)
    if os.path.isfile(token_path):
        print("Remove token", token_path)
        os.unlink(token_path)


def get_altnames(path):
    """
    Return the DNS names of a certificate file, with the subject commonName first.

    The commonName is moved to the front when it is also listed as a
    subjectAltName, and added at the front when it is not. Certificates
    without a commonName or without subjectAltName entries are accepted.

    :param path: certificate file to decode
    :return: list of names (may be empty)
    """
    # NOTE(review): _test_decode_cert is a private CPython helper - confirm it
    # is available on every interpreter this tool is expected to run on.
    crt = ssl._ssl._test_decode_cert(path)
    altnames = [t[1] for t in crt.get('subjectAltName', ()) if t[0] == 'DNS']
    subject = dict(x[0] for x in crt.get('subject', ()))
    cn = subject.get('commonName')
    if cn:
        if cn in altnames:
            altnames.remove(cn)
        altnames.insert(0, cn)
    return altnames

def mkcert(webroot, names, opts=None):
    """
    Run ``certbot certonly --webroot`` for the given names.

    :param webroot: directory passed to certbot with -w
    :param names: domain names, each passed with its own -d
    :param opts: optional extra command line arguments appended as-is
    :return: exit status of the certbot process
    """
    domain_args = [arg for name in names for arg in ('-d', name)]
    command = ['certbot', 'certonly', '--webroot', '-w', webroot] + domain_args

    if opts:
        command.extend(opts)

    print("RUNNING: {}".format(' '.join(command)))
    finished = subprocess.run(command)
    return finished.returncode


def _build_parser(def_apacheconf, def_lepath):
    """Build the argument parser; the defaults are passed in so help texts can show them."""
    # The checking examples need --process: without a command the tool only prints "Need command".
    epilog = 'Examples:\n' \
             '# Verify all LetsEncrypt config\n' \
             '{me} --process\n\n'.format(me=sys.argv[0])

    epilog += "# Verify one LetsEncrypt certificate:\n" \
              "{me} --process --host example.com\n\n".format(me=sys.argv[0])

    epilog += "# Verify if cert could be requested (preparation). Existing certificate not needed (all manual):\n" \
              "{me} --prepare -d example.com -d www.example.com -w /var/www/virtual/example.com\n" \
              "# Or much more simpler (aliases and webroot will be guessed from apache config):\n" \
              "{me} --prepare -d example.com --aliases\n\n".format(me=sys.argv[0])

    epilog += "# Create certificate for example.com and all of it's aliases" \
        "(www.example.com, example.net, www.example.net)\n" \
        "{me} --create -d example.com --aliases\n" \
        "{me} --create -d example.com -d www.example.com -d example.net -d www.example.net\n" \
        "\n".format(me=sys.argv[0])

    parser = argparse.ArgumentParser(description='Apache2 / Certbot misconfiguration diagnostic', epilog=epilog,
                                     formatter_class=argparse.RawTextHelpFormatter)

    g = parser.add_argument_group('Check existent LetsEncrypt verification (if "certbot renew" fails)')
    g.add_argument('--host', default=None, metavar='HOST',
                   help='Process only letsencrypt config file for HOST. def: {}'.format(None))
    g.add_argument('--process', default=False, action='store_true', help='Process all letsencrypt certificates')
    g.add_argument('--lepath', default=def_lepath, nargs='?', dest='lepath', metavar='LETSENCRYPT_DIR_PATH',
                   help='Lets Encrypt directory def: {}'.format(def_lepath))

    g = parser.add_argument_group('Check non-existent LetsEncrypt verification (if "certbot certonly --webroot" fails)')
    g.add_argument('--prepare', default=False, action='store_true',
                   help='Preparation check (before requesting LetsEncrypt cert). '
                        'You may also use --aliases option')
    g.add_argument('-w', '--webroot', help='DocumentRoot for new website')
    # action='append' keeps every occurrence of -d; without it only the last
    # "-d NAME" on the command line would survive.
    g.add_argument('-d', '--domain', nargs='*', action='append', metavar='DOMAIN',
                   help='hostname/domain for new website')

    g = parser.add_argument_group('General options')
    g.add_argument('--altroot', default=None, metavar='DocumentRoot',
                   help='Try also other root (in case if Alias used). def: {}'.format(None))
    g.add_argument('-c', '--conf', dest='apacheconf', nargs='?', default=def_apacheconf, metavar='PATH',
                   help='Config file path (def: {})'.format(def_apacheconf))
    g.add_argument('-v', '--verbose', action='store_true',
                   default=False, help='verbose mode')
    g.add_argument('-q', '--quiet', action='store_true',
                   default=False, help='quiet mode, suppress output for sites without problems')
    g.add_argument('-i', '--ip', nargs='*',
                   help='Default addresses. Autodetect if not specified')
    g.add_argument(metavar='CERTBOT_OPTS', nargs='*', dest='opts',
                   help='options to pass to certbot, e.g. --test-cert')

    g = parser.add_argument_group('Generate new certificate (certbot certonly --webroot)')
    g.add_argument('--create', default=False, action='store_true',
                   help='Create LetsEncrypt certificate (via certbot). Use also -d and --aliases')
    g.add_argument('--recreate', metavar='CERT.pem',
                   help='Re-create certificate, use names from PATH, and webroot from apache config or --webroot')
    g.add_argument('--aliases', action='store_true',
                   default=False, help='Include all ServerName and ServerAlias found in VirtualHost')

    g = parser.add_argument_group('Assisting in  (certbot certonly --manual)')
    g.add_argument('--place', nargs=2, metavar=('CERTBOT_TOKEN', 'CERTBOT_VALIDATION'),
                   help='place validation value into token file')
    g.add_argument('--challenge-path', '-p', help='Optional path to token directory (/something/.well-known/acme-challenge/)')
    g.add_argument('--cleanup', metavar='CERTBOT_TOKEN', help='cleanup validation token')
    g.add_argument('--altnames', metavar='CERT.PEM', help='list altnames for certificate')

    return parser


def main():
    """
    Command line entry point: parse arguments, set up logging and run the
    selected command (--prepare, --create, --recreate, --place, --cleanup,
    --altnames or --process).
    """
    global log

    def_apacheconf = '/etc/apache2/apache2.conf'
    def_lepath = '/etc/letsencrypt/renewal/'

    args = _build_parser(def_apacheconf, def_lepath).parse_args()

    # Each -d occurrence yields its own list of names; flatten to a single list.
    if args.domain:
        args.domain = [name for group in args.domain for name in group]
    # nargs='?' gives None when the flag is present without a value.
    args.apacheconf = args.apacheconf or def_apacheconf
    args.lepath = args.lepath or def_lepath

    logging.basicConfig(
        format='%(message)s',
        level=logging.INFO)

    log = logging.getLogger('diag')

    if args.verbose:
        log.setLevel(logging.DEBUG)
        log.debug('Verbose mode')

    def fail(msg):
        """Print msg and terminate with exit status 1."""
        print(msg)
        sys.exit(1)

    def local_ips():
        """Return --ip if given, otherwise autodetect; called only by commands that need it."""
        if args.ip:
            ips = args.ip
        else:
            log.debug("Autodetect IP")
            ips = detect_ip()
        log.debug("my IP list: {}".format(ips))
        return ips

    def find_webroot(name):
        """Return --webroot if given, else the DocumentRoot of the vhost serving name."""
        if args.webroot:
            return args.webroot
        try:
            return get_webroot(name, args.apacheconf)
        except VhostNotFound as e:
            fail(e)
        except StopIteration:
            fail("Cannot find webroot for {} (starting from {}). Maybe site is not yet created/enabled?".format(
                name, args.apacheconf))

    def requested_names():
        """Return the names from -d, main name first, extended by vhost aliases with --aliases."""
        if not args.aliases:
            return list(args.domain)
        try:
            found = get_aliases(args.domain, args.apacheconf)
        except VhostNotFound as e:
            fail(e)
        main_name = args.domain[0]
        # Sorted so the certbot command line is reproducible between runs.
        return [main_name] + sorted(n for n in found if n != main_name)

    if args.prepare:
        if not args.domain:
            fail("--prepare requires at least one -d and optional --aliases")
        names = requested_names()
        webroot = find_webroot(args.domain[0])
        lc = LetsEncryptCertificateConfig(path=None, webroot=webroot, domains=names)
        process_file(leconf_path=None, local_ip_list=local_ips(), args=args, leconf=lc)

    elif args.create:
        if not args.domain:
            fail("--create requires at least one -d and optional --aliases")

        name = args.domain[0]
        print("Create cert for {}".format(name))

        webroot = find_webroot(name)
        names = requested_names()
        sys.exit(mkcert(webroot, names, args.opts))

    elif args.recreate:
        pemfile = args.recreate
        print("recreate", pemfile)
        names = get_altnames(pemfile)
        if not names:
            fail("No names found in {}".format(pemfile))
        webroot = find_webroot(names[0])
        print(f"cert for {names} webroot {webroot}")
        # Pass the extra certbot options through and propagate certbot's exit status.
        sys.exit(mkcert(webroot, names, args.opts))

    elif args.place:
        if not args.domain:
            fail("--place requires -d DOMAIN")
        place(
            apacheconfig=args.apacheconf,
            challenge_dir=args.challenge_path,
            token=args.place[0],
            validation=args.place[1],
            domain=args.domain[0])

    elif args.cleanup:
        if not args.domain:
            fail("--cleanup requires -d DOMAIN")
        cleanup(
            apacheconfig=args.apacheconf,
            challenge_dir=args.challenge_path,
            token=args.cleanup,
            domain=args.domain[0])

    elif args.altnames:
        altnames = get_altnames(args.altnames)
        print(' '.join(altnames))

    elif args.process:
        if os.path.isdir(args.lepath):
            local_ip_list = local_ips()
            for f in sorted(os.listdir(args.lepath)):
                path = os.path.join(args.lepath, f)
                if not path.endswith('.conf'):
                    continue
                if not (os.path.isfile(path) or os.path.islink(path)):
                    continue
                process_file(leconf_path=path, local_ip_list=local_ip_list, args=args)
        elif os.path.isfile(args.lepath):
            process_file(leconf_path=args.lepath, local_ip_list=local_ips(), args=args)
        else:
            fail("LetsEncrypt path {} not found".format(args.lepath))

    else:
        print("Need command")

# Runs the tool whenever this file is executed or imported: the call is
# unconditional (there is no ``if __name__ == '__main__'`` guard).
main()
