# -*- coding: utf-8 -*-
#
# Copyright 2017 Joseph Weston
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""Command-line interface to the NordVPN client."""

import sys
import traceback
import signal
import logging
import argparse
import asyncio
import ipaddress

import structlog
from termcolor import colored
import aiohttp
import aiohttp.web

from . import api, vpn, __version__
from . import web as nord_web
from ._utils import sudo_requires_password, prompt_for_sudo, LockError


class Abort(RuntimeError):
    """Raised to make the command-line interface stop with an error.

    ``main`` catches this exception, prints its message after a red, bold
    "Error" prefix and exits with status 1. An empty message is allowed for
    cases where the user has already been told what went wrong.
    """


def main():
    """Execute the nord command-line interface.

    Parses the command line, configures logging, runs the selected subcommand
    coroutine on a fresh event loop and exits the process with the
    subcommand's return value (``None`` -> status 0). Cancellation (e.g. via
    SIGHUP/SIGINT/SIGTERM) and :class:`Abort` both exit with status 1; for
    ``Abort`` the error message is printed first.
    """
    # parse command line arguments
    args = parse_arguments()

    # subcommands are module-level coroutine functions named after the command
    command = globals()[args.command]

    setup_logging(args)

    # Create and install the loop explicitly: letting get_event_loop() create
    # one implicitly is deprecated and fails on recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    for sig in (signal.SIGHUP, signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, cancel_all_tasks)

    # dispatch
    try:
        returncode = loop.run_until_complete(command(args))
    except asyncio.CancelledError:
        returncode = 1
    except Abort as error:
        print(f"{colored('Error', 'red', attrs=['bold'])}:", error)
        returncode = 1
    finally:
        try:
            # give cancelled tasks a chance to run their cleanup code
            remaining_tasks = cancel_all_tasks()
            if remaining_tasks:
                loop.run_until_complete(asyncio.wait(remaining_tasks))
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()

    sys.exit(returncode)


def cancel_all_tasks(loop=None):
    """Cancel all pending tasks on an event loop.

    Parameters
    ----------
    loop : asyncio.AbstractEventLoop, optional
        The loop whose tasks are cancelled. Defaults to
        ``asyncio.get_event_loop()``: the running loop when called from a
        callback (e.g. a signal handler), otherwise the loop set for the
        current thread.

    Returns
    -------
    set of asyncio.Task
        The tasks that were pending and have now been asked to cancel; the
        caller may wait on them to let their cleanup run.
    """
    if loop is None:
        loop = asyncio.get_event_loop()
    try:
        all_tasks = asyncio.all_tasks
    except AttributeError:  # Python < 3.7
        all_tasks = asyncio.Task.all_tasks
    # Pass the loop explicitly: asyncio.all_tasks() without an argument
    # requires a *running* loop, which is not the case during shutdown.
    remaining_tasks = {task for task in all_tasks(loop) if not task.done()}
    for task in remaining_tasks:
        task.cancel()
    return remaining_tasks


def render_logs(logger, _, event):
    """Render a structlog event dict into a string suitable for CLI output.

    Parameters
    ----------
    logger
        The wrapped logger; its ``name`` is printed as a bold prefix.
    _
        The name of the log method that was called (unused).
    event : dict
        The structlog event dictionary. Recognized keys: ``'event'`` (the
        message), ``'stream'`` (``'status'`` or ``'stdout'``), ``'timestamp'``
        (used for ``'stdout'`` events) and ``'exc_info'``.

    Returns
    -------
    str
        ``"[<logger name>] <message>"``.
    """
    stream = event.get('stream', '')
    if stream == 'status' and event['event'] == 'up':
        msg = colored('connected', 'green', attrs=['bold'])
    elif stream == 'status' and event['event'] == 'down':
        msg = colored('disconnected', 'red', attrs=['bold'])
    elif stream == 'stdout':
        msg = f"[stdout @ {event['timestamp']}] {event['event']}"
    elif event.get('exc_info'):
        exc_info = event['exc_info']
        if isinstance(exc_info, BaseException):
            exc_info = (type(exc_info), exc_info, exc_info.__traceback__)
        elif not isinstance(exc_info, tuple):
            # e.g. exc_info=True, as passed by ``log.exception(...)``:
            # use the exception currently being handled.
            exc_info = sys.exc_info()
        # format_exception returns a list of lines; join them into one string
        details = ''.join(traceback.format_exception(*exc_info)).rstrip('\n')
        msg = f"{event.get('event', '')}\n{details}"
    else:
        # also covers 'status' events other than 'up'/'down'
        msg = f"{event['event']}"
    return f"[{colored(logger.name, attrs=['bold'])}] {msg}"


def setup_logging(args):
    """Set up logging for the command-line interface.

    Appends :func:`render_logs` to structlog's processor chain (at most once)
    and configures the standard-library root logger to print bare messages to
    stdout. The level is DEBUG when ``args.debug`` is set and INFO otherwise;
    ``args`` without a ``debug`` attribute counts as not set.
    """
    # Relies on get_config() exposing structlog's live processor list, which
    # is mutated in place. Appending the renderer twice would hand an
    # already-rendered string to the second copy, so guard against that.
    processors = structlog.get_config()['processors']
    if render_logs not in processors:
        processors.append(render_logs)

    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG if getattr(args, 'debug', False) else logging.INFO,
        format='%(message)s',
    )

    # Keep 'asyncio' records away from the root handler configured above.
    # NOTE: 'asyncio' has no handler of its own, so WARNING and above still
    # reach stderr through logging.lastResort.
    logging.getLogger('asyncio').propagate = False


def parse_arguments(argv=None):
    """Parse the Nord command line and return the resulting namespace.

    Parameters
    ----------
    argv : list of str, optional
        The arguments to parse. Defaults to ``None``, in which case argparse
        reads ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        The parsed arguments; ``command`` holds the selected subcommand
        (``'ip_address'``, ``'connect'`` or ``'web'``).

    If no subcommand is given, ``parser.error`` prints the usage and exits.
    """
    parser = argparse.ArgumentParser(
        'nord',
        description='An unofficial NordVPN client')
    subparsers = parser.add_subparsers(dest='command')

    parser.add_argument('--version', action='version',
                        version=f'nord {__version__}')

    def _add_account_arguments(subparser):
        # options shared by every subcommand that logs in to NordVPN
        subparser.add_argument('--debug', action='store_true',
                               help='Print debugging information')
        subparser.add_argument('-u', '--username', type=str,
                               required=True,
                               help='NordVPN account username')

    def _add_password_arguments(subparser):
        # methods of password entry; exactly one must be given
        passwd = subparser.add_mutually_exclusive_group(required=True)
        passwd.add_argument('-p', '--password', type=str,
                            help='NordVPN account password')
        passwd.add_argument('-f', '--password-file',
                            type=argparse.FileType(),
                            help='Path to file containing NordVPN password')

    def _flag(country):
        code = str(country)
        # Validate ASCII letters *before* upper-casing: str.isalpha() accepts
        # any Unicode letter, and e.g. 'ß'.upper() == 'SS'.
        if len(code) != 2 or not all(
                'a' <= char <= 'z' or 'A' <= char <= 'Z' for char in code):
            raise argparse.ArgumentTypeError(
                'must be a 2 letter country code')
        return code.upper()

    subparsers.add_parser(
        'ip_address',
        help="Get our public IP address, as reported by NordVPN.")

    connect_parser = subparsers.add_parser(
        'connect',
        help="connect to a NordVPN server",
        description="Connect to a nordVPN server. If the '--server' argument "
                    "is provided, connect to that specific server, otherwise "
                    "select all hosts in the provided country, filter them "
                    "by their load, and select the closest one.")
    _add_account_arguments(connect_parser)
    _add_password_arguments(connect_parser)

    # pre-filters on the hostlist. Either specify a country or a single host
    hosts = connect_parser.add_mutually_exclusive_group(required=True)
    hosts.add_argument('country_code', type=_flag, nargs='?',
                       help='2-letter country code, e.g. US, GB')
    hosts.add_argument('-s', '--server',
                       help='NordVPN host or fully qualified domain name, '
                            'e.g us720, us270.nordvpn.com')

    # arguments to filter the resulting hostlist
    connect_parser.add_argument('--ping-timeout', type=int, default=2,
                                help='Wait for this long for responses from '
                                     'potential hosts')
    connect_parser.add_argument('--max-load', type=int, default=70,
                                help='Reject hosts that have a load greater '
                                     'than this threshold')

    web_parser = subparsers.add_parser(
        'web',
        help="Run nord as a web app",
        description="Serve a web app that provides a GUI for selecting the "
                    "country to connect to.")
    _add_account_arguments(web_parser)
    web_parser.add_argument('-P', '--port', type=int, default=8000,
                            help='Port on which to run the web app')
    web_parser.add_argument('-H', '--host', type=ipaddress.ip_address,
                            default='127.0.0.1',
                            help='IP address on which to run the web app')
    _add_password_arguments(web_parser)

    args = parser.parse_args(argv)

    if not args.command:
        parser.error('no command provided')

    return args


# Subcommands

async def ip_address(_):
    """Print the public IP address that NordVPN reports for this machine.

    The parsed command-line arguments are accepted so that this subcommand has
    the same signature as the others, but they are not used.
    """
    async with api.Client() as nord:
        public_ip = await nord.current_ip()
        print(public_ip)


async def connect(args):
    """Connect to a NordVPN server and run OpenVPN.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed ``connect`` arguments: ``username``, ``password`` or
        ``password_file``, plus the host-selection options consumed by
        ``_get_host_and_config``.

    Raises
    ------
    Abort
        If the credentials are invalid, no usable host/config is found, sudo
        authentication fails, the lock cannot be obtained, or OpenVPN reports
        an error.
    """
    username = args.username
    # '-p' and '-f' are mutually exclusive. Test for None explicitly so that
    # an empty '-p ""' does not fall through to the (absent) password file.
    if args.password is not None:
        password = args.password
    else:
        password = args.password_file.readline().strip()

    # Group requests together to reduce overall latency
    async with api.Client() as client:
        output = await asyncio.gather(
            _get_host_and_config(client, args),
            client.valid_credentials(username, password),
            client.dns_servers(),
            sudo_requires_password(),
        )
    (host, config), valid_credentials, dns_servers, require_sudo = output

    if not valid_credentials:
        raise Abort('invalid username/password combination')

    log = structlog.get_logger(__name__)
    log.info(f"connecting to {host}")

    if require_sudo:
        print('sudo password required for OpenVPN')
        try:
            await prompt_for_sudo()
        except PermissionError:
            # 'sudo' will already have notified the user about the failure
            raise Abort()

    try:
        await vpn.run(config, username, password, dns_servers)
    except LockError as error:
        raise Abort('Failed to obtain a lock: is another instance '
                    'of nord running?') from error
    except vpn.OpenVPNError as error:
        raise Abort(str(error)) from error


async def web(args):
    """Run nord as a web app.

    Validates the credentials, obtains sudo rights if required, then serves
    the app built by ``nord_web.init_app`` on ``args.host``:``args.port``
    until ``app['shutdown_signal'].wait()`` returns.

    Raises
    ------
    Abort
        If the credentials are invalid, sudo authentication fails, or the
        server cannot listen on the requested address.
    """
    username = args.username
    # '-p' and '-f' are mutually exclusive. Test for None explicitly so that
    # an empty '-p ""' does not fall through to the (absent) password file.
    if args.password is not None:
        password = args.password
    else:
        password = args.password_file.readline().strip()

    async with api.Client() as client:
        # Group requests together to reduce overall latency
        valid_credentials, require_sudo = await asyncio.gather(
            client.valid_credentials(username, password),
            sudo_requires_password(),
        )

        if not valid_credentials:
            raise Abort('invalid username/password combination')

        if require_sudo:
            print('sudo password required for OpenVPN')
            try:
                await prompt_for_sudo()
            except PermissionError:
                # 'sudo' will already have notified the user about the failure
                raise Abort()

        app = nord_web.init_app(client, (username, password))
        runner = aiohttp.web.AppRunner(app)
        await runner.setup()
        # Clean up the runner even if binding the listening socket fails.
        try:
            site = aiohttp.web.TCPSite(runner, str(args.host), args.port)
            try:
                await site.start()
            except OSError as error:
                # e.g. the address is already in use
                raise Abort(f'cannot listen on {args.host}:{args.port}: '
                            f'{error}') from error
            print(colored(f'=== Listening {args.host}:{args.port} ===',
                          color='white', attrs=['bold']))
            await app['shutdown_signal'].wait()
        finally:
            await runner.cleanup()


async def _get_host_and_config(client, args):
    """Select a NordVPN host and fetch its OpenVPN configuration.

    If ``args.server`` is given, only that (normalized) server is tried.
    Otherwise, the hosts returned by ``client.rank_hosts`` for
    ``args.country_code`` (with ``args.max_load`` and ``args.ping_timeout``)
    are tried in the order returned.

    Returns
    -------
    (host, config)
        The first host whose configuration could be fetched, and that
        configuration.

    Raises
    ------
    Abort
        If the server name is invalid, no hosts are available, or fetching
        the configuration gives HTTP 404 for every candidate host.
    aiohttp.ClientResponseError
        For HTTP errors other than 404.
    """
    # get the candidate hosts
    if args.server:
        try:
            hosts = [api.normalized_hostname(args.server)]
        except ValueError as error:
            raise Abort(f'{args.server} is not a NordVPN server') from error
    else:
        # argparse requires either a server or a country code; check
        # explicitly rather than with 'assert', which '-O' strips.
        if not args.country_code:
            raise Abort('no country code or server provided')
        hosts = await client.rank_hosts(args.country_code,
                                        args.max_load, args.ping_timeout)
        if not hosts:
            raise Abort('no hosts available '
                        '(try a higher load or ping threshold?)')

    # get the config from the first host that has one
    for host in hosts:
        try:
            config = await client.host_config(host)
        except aiohttp.ClientResponseError as error:
            # 404: no config for this host, so try the next one.
            # ('status' replaces the deprecated 'code' attribute.)
            if error.status != 404:
                raise  # unexpected error
        else:
            return host, config
    # 'hosts' is non-empty here, so 'host' is the last candidate tried.
    # pylint: disable=undefined-loop-variable
    raise Abort(f"config unavailable for {host}")