#!python
from builtins import input

import os
import json
import click
import itertools
from glob import glob

from abraia.utils import load_image, save_image, get_type, list_dir, process_map
from abraia.editing import edit_image

from abraia import config
from abraia import Abraia
from abraia import APIError
from abraia import __version__


# Shared API client used by the commands and helpers in this module (file
# listing/upload/download/removal, metadata, prediction). Its ``userid`` is
# also checked in the ``__main__`` entry point. Created at import time, so
# importing this module instantiates the client.
abraia = Abraia()


def list_files(folder):
    """Fetch the remote listing of *folder*; callers unpack it as ``(files, folders)``."""
    listing = abraia.list_files(folder)
    return listing


def upload_file(file, folder):
    """Send the local *file* to *folder* on the cloud storage and return the client's result."""
    client = abraia
    return client.upload_file(file, folder)


def download_file(path, folder):
    """Fetch remote *path* into local *folder*, keeping the remote base name."""
    filename = os.path.basename(path)
    target = os.path.join(folder, filename)
    return abraia.download_file(path, target)


def remove_file(path):
    """Delete the remote file at *path* and return the client's result."""
    result = abraia.remove_file(path)
    return result


def echo_error(error):
    """Print *error* as ``[Error <code>] <message>``, with the tag in bold red.

    *error* must expose ``code`` and ``message`` attributes (as the APIError
    instances caught throughout this module are expected to).
    """
    tag = click.style('Error {}'.format(error.code), fg='red', bold=True)
    detail = '] {}'.format(error.message)
    click.echo('[' + tag + detail)


def input_files(src):
    """Expand *src* into the list of local files to process.

    *src* may be a directory (walked recursively), a single file path, or a
    glob pattern (``**`` is honoured because ``recursive=True``).

    Only regular files are returned. ``glob('<dir>/**/*')`` also yields the
    sub-directories themselves, and callers hand every returned path to a
    per-file worker (upload, image editing) that cannot process a directory.
    A path or pattern that matches nothing yields an empty list.
    """
    pattern = os.path.join(src, '**/*') if os.path.isdir(src) else src
    return [path for path in glob(pattern, recursive=True) if os.path.isfile(path)]


# Root command group; the subcommands and subgroups below attach to it via
# ``@cli.command()`` / ``@cli.group(...)``. The docstring is the CLI help text.
@click.group('abraia')
@click.version_option(__version__)
def cli():
    """Abraia CLI tool"""


@cli.command()
def configure():
    """Configure the abraia api key."""
    # Interactively ask for the user id and key, offering the currently
    # stored values as defaults, then persist them with ``config.save``.
    click.echo('Go to [' + click.style('https://abraia.me/editor/', fg='green') + '] to get your user id and key\n')
    try:
        abraia_id, abraia_key = config.load()
    except Exception:
        # NOTE(review): assumed to fail only when no usable config exists yet.
        # Prompt without defaults instead of silently skipping configuration.
        abraia_id, abraia_key = None, None
    try:
        abraia_id = click.prompt('Abraia Id', default=abraia_id)
        abraia_key = click.prompt('Abraia Key', default=abraia_key)
        config.save(abraia_id, abraia_key)
    except click.Abort:
        # The user cancelled a prompt (Ctrl-C / EOF): leave quietly.
        pass
    except Exception as error:
        # Do not hide failures (e.g. an unwritable config file): tell the
        # user the credentials were not stored.
        click.echo('[' + click.style('Error', fg='red', bold=True) + '] {}'.format(error))


def editing_file(src, dest, mode, size=None):
    """Apply the edit *mode* to the image at *src* and save the result to *dest*.

    Args:
        src: path of the input image.
        dest: output path; its parent directory is created if missing.
        mode: edit name forwarded to ``edit_image`` (e.g. 'removebg').
        size: optional target size forwarded to ``edit_image``.

    Returns:
        *dest*, so a parallel map over this function yields the written paths.
    """
    img = load_image(src)
    out = edit_image(img, mode, size=size)
    # Callers build destinations such as ``output/<subdir>/<name>.jpg`` and
    # nothing else creates that tree; exist_ok tolerates concurrent workers.
    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)
    save_image(out, dest)
    return dest


def editing_files(src, mode, size=None, desc="Editing"):
    """Apply the edit *mode* to every image selected by *src*.

    Results are written under ``output/``, mirroring each input's path
    relative to *src* (or to its parent directory when *src* is a file or a
    glob). The extension is ``.png`` for 'removebg' and ``.jpg`` otherwise.
    An ``APIError`` is reported with ``echo_error`` instead of propagating.

    Args:
        src: directory, file path or glob pattern (see ``input_files``).
        mode: edit name forwarded to ``editing_file``.
        size: optional target size forwarded to ``editing_file``.
        desc: label passed to ``process_map`` (presumably a progress label).
    """
    try:
        inputs = input_files(src)
        new_ext = 'png' if mode == 'removebg' else 'jpg'
        base_dir = src if os.path.isdir(src) else os.path.dirname(src)
        outputs = [_editing_output(path, base_dir, new_ext) for path in inputs]
        process_map(editing_file, inputs, outputs, itertools.repeat(mode), itertools.repeat(size), desc=desc)
    except APIError as error:
        echo_error(error)


def _editing_output(path, base_dir, ext):
    """Map input image *path* to its ``output/...<name>.<ext>`` destination."""
    rel = os.path.relpath(path, base_dir)
    # A wildcard in the directory part of a glob (e.g. ``imgs/**/*.jpg``)
    # leaves base_dir as a non-existent prefix, so relpath climbs out with
    # ``..``. Fall back to the bare file name so nothing lands outside output/.
    if rel == os.pardir or rel.startswith(os.pardir + os.sep):
        rel = os.path.basename(path)
    return os.path.join('output', f"{os.path.splitext(rel)[0]}.{ext}")


# ``abraia editing ...`` subgroup; each image-editing command below
# registers itself with ``@cli_editing.command()``.
@cli.group('editing')
def cli_editing():
    """Convert and edit images in bulk."""


@cli_editing.command()
@click.option('--width', help='Width of the cropped image', type=int)
@click.option('--height', help='Height of the cropped image', type=int)
@click.argument('src')
def smartcrop(src, width, height):
    """Smart crop images to the specified size."""
    # A target size is only passed on when BOTH dimensions are given and
    # non-zero; otherwise edit_image receives size=None.
    size = None
    if width and height:
        size = (width, height)
    editing_files(src, 'smartcrop', size=size, desc="Smart cropping")


@cli_editing.command()
@click.argument('src')
def removebg(src):
    """Remove the images background to make them transparent."""
    # This mode's results are written with a .png extension.
    editing_files(src=src, mode='removebg', desc="Removing background")


@cli_editing.command()
@click.argument('src')
def upscale(src):
    """Upscale and enhance images increasing the resolution."""
    # Thin wrapper: bulk-apply the 'upscale' edit to everything under SRC.
    editing_files(src=src, mode='upscale', desc="Upscaling images")


@cli_editing.command()
@click.argument('src')
def anonymize(src):
    """Anonymize image blurring faces and car license plates."""
    # Thin wrapper: bulk-apply the 'anonymize' edit to everything under SRC.
    editing_files(src=src, mode='anonymize', desc="Anonymizing images")


@cli_editing.command()
@click.argument('src')
def blur(src):
    """Blur the image background to focus attention on the main object."""
    # Thin wrapper: bulk-apply the 'blur' edit to everything under SRC.
    editing_files(src=src, mode='blur', desc="Blur background")


@cli_editing.command()
@click.argument('src')
def clean(src):
    """Clean images removing unwanted objects with inpainting."""
    # Thin wrapper: bulk-apply the 'clean' edit to everything under SRC.
    editing_files(src=src, mode='clean', desc="Removing objects")


# ``abraia files ...`` subgroup for remote storage operations; commands
# register themselves with ``@cli_files.command()``.
@cli.group('files')
def cli_files():
    """Manage files on the cloud storage."""


def format_output(files, folders=None):
    """Render a remote listing as text, one entry per line.

    Layout:
        * each folder: a fixed 28-column pad, then its name in bold blue
          followed by ``/``;
        * each file: ``<date>  <size right-aligned to 7>  <name>``;
        * a final ``total N`` line, where N counts files only.

    Lines are joined with ``\\n`` without empty separators, so an empty
    folder list no longer produces a leading blank line.

    Args:
        files: sequence of dicts with 'date', 'size' and 'name' keys.
        folders: optional sequence of dicts with a 'name' key.
    """
    lines = ['{:>28}  {}/'.format('', click.style(f['name'], fg='blue', bold=True))
             for f in folders or []]
    lines += ['{}  {:>7}  {}'.format(f['date'], f['size'], f['name']) for f in files]
    lines.append('total {}'.format(len(files)))
    return '\n'.join(lines)


@cli_files.command('list')
@click.argument('folder', required=False, default='')
def list_command(folder):
    """List files in abraia."""
    # The function is not named ``list`` so it does not shadow the builtin;
    # the explicit command name keeps ``abraia files list`` unchanged.
    try:
        files, folders = list_files(folder)
    except APIError as error:
        echo_error(error)
        return
    click.echo(format_output(files, folders))


@cli_files.command()
@click.argument('src', type=click.Path())
@click.argument('folder', required=False, default='')
def upload(src, folder):
    """Upload files to abraia."""
    # Every local file selected by SRC goes into the same remote FOLDER;
    # API failures are reported rather than raised.
    try:
        local_paths = input_files(src)
        same_folder = itertools.repeat(folder)
        process_map(upload_file, local_paths, same_folder, desc="Uploading")
    except APIError as error:
        echo_error(error)


@cli_files.command()
@click.argument('path')
@click.argument('folder', required=False, default='')
def download(path, folder):
    """Download files from abraia."""
    # Only the files part of the remote listing is fetched; each one is saved
    # into local FOLDER under its remote base name.
    try:
        listing = list_files(path)[0]
        remote_paths = [item['path'] for item in listing]
        process_map(download_file, remote_paths, itertools.repeat(folder), desc="Downloading")
    except APIError as error:
        echo_error(error)


@cli_files.command()
@click.argument('path')
def remove(path):
    """Remove files from abraia."""
    # Show what matches PATH first; delete only after explicit confirmation.
    try:
        matches = list_files(path)[0]
        click.echo(format_output(matches))
        if not matches:
            return
        if click.confirm('Are you sure you want to remove the files?'):
            process_map(remove_file, [item['path'] for item in matches], desc="Removing")
    except APIError as error:
        echo_error(error)


@cli_files.command()
@click.option('--remove', help='Remove file metadata', is_flag=True)
@click.argument('path')
def metadata(path, remove):
    """Load or remove file metadata."""
    # With --remove the metadata is stripped first; in both cases the file's
    # metadata is then loaded and printed as indented JSON.
    try:
        if remove:
            abraia.remove_metadata(path)
        payload = json.dumps(abraia.load_metadata(path), indent=2)
        click.echo(payload)
    except APIError as error:
        echo_error(error)


# NOTE(review): this rebinds the module-level name ``list`` (shadowing the
# builtin); the name is kept because it is what ``abraia list`` is derived from.
@cli.command()
def list():
    """List available datasets."""
    # Imported inside the command, presumably to keep CLI start-up light.
    from abraia.training import list_datasets
    datasets = list_datasets()
    click.echo(json.dumps(datasets, indent=2))


def convert_to_jpg(src):
    """Convert the image at *src* to a ``.jpg`` next to it, then delete *src*.

    The output path is *src* with its extension replaced by ``.jpg``. The
    source is removed only if that output exists and is a different file.
    This guards against deleting the freshly written image when the two paths
    resolve to the same file (a ``.jpg`` name, or a case variant such as
    ``.JPG`` on a case-insensitive filesystem). It also keeps the source if
    nothing was written.
    """
    dest = f"{os.path.splitext(src)[0]}.jpg"
    img = load_image(src)
    save_image(img, dest)
    if os.path.exists(dest) and not os.path.samefile(src, dest):
        os.remove(src)


def process_dataset(project):
    """Normalize the images in local *project* and return its file list.

    HEIC images are converted to JPEG (originals removed). TIFF images are
    passed to ``create_visible``. The directory is re-listed after each step
    that ran, so the returned list reflects the files on disk afterwards.
    """
    def _of_type(paths, mime):
        return [path for path in paths if get_type(path) == mime]

    files = list_dir(project)

    heic_files = _of_type(files, 'image/heic')
    if heic_files:
        process_map(convert_to_jpg, heic_files, desc="Converting images")
        files = list_dir(project)

    tiff_files = _of_type(files, 'image/tiff')
    if tiff_files:
        from abraia.multiple import create_visible
        process_map(create_visible, tiff_files, desc="Creating images")
        files = list_dir(project)

    return files


@cli.command()
@click.argument('project')
@click.argument('query', required=False, default='')
def create(project, query):
    """Create or update a dataset."""
    # Optionally fetch up to 100 images for QUERY into the local PROJECT
    # directory, normalize whatever is there, upload it to the remote folder
    # of the same name, then save the dataset.
    from abraia.training import search_images, load_dataset
    dataset = load_dataset(project)
    prefix = f"{project}/"
    files = search_images(query, limit=100, output_dir=prefix, verbose=True) if query else []
    if os.path.exists(project):
        files = process_dataset(project)
    process_map(upload_file, files, itertools.repeat(prefix), desc="Uploading images")
    dataset.save()


@cli.command()
@click.argument('project')
@click.argument('label', type=str, required=False, default='')
@click.option('--segment', help='Segment objects from boxes', is_flag=True)
def annotate(project, label, segment=False):
    """Annotate a dataset using Grounding Dino."""
    from abraia.training import load_dataset
    # With no LABEL (the default) the command does nothing.
    if not label:
        return
    dataset = load_dataset(project)
    dataset.save(dataset.annotate(label, segment=segment))


@cli.command()
@click.argument('project')
@click.argument('epochs', type=int, required=False, default=None)
def train(project, epochs):
    """Train a model on the specified dataset."""
    from abraia.training import load_dataset, prepare_dataset, ModelTrainer

    # Stage 1: load and prepare the dataset.
    click.echo('Loading dataset...')
    dataset = load_dataset(project)
    prepare_dataset(dataset)

    # Stage 2: train (EPOCHS is None when not given on the command line).
    click.echo('Training model...')
    trainer = ModelTrainer(project, dataset.task, dataset.classes)
    trainer.train(epochs)

    # Stage 3: persist the trained model.
    click.echo('Saving model...')
    trainer.save()


@cli.command()
@click.argument('project')
@click.argument('src', required=False, default=None)
def predict(project, src):
    """Run trained model on a specified image."""
    # SRC must not carry a numeric default: click infers the parameter type
    # from ``default`` (0 -> INT) and would reject image/video paths. A
    # missing or digit-only SRC becomes an int source index, which matches
    # the values the old INT-typed argument passed to process_media
    # (presumably a camera index -- confirm).
    if src is None:
        src = 0
    elif src.isdigit():
        src = int(src)
    try:
        from abraia.training import list_models
        from abraia.inference.detect import Model
        from abraia.utils import process_media, render_results
        models = list_models(project)
        if not models:
            click.echo(f"No trained model found in project '{project}'")
            return
        # Only the first model reported by list_models is used.
        model_name = models[0]
        model_uri = f"{abraia.userid}/{project}/{model_name}"
        model = Model(model_uri)

        def callback(img):
            # Run the model on one image/frame and return render_results' output.
            results = model.run(img)
            return render_results(img, results)

        process_media(src, callback)
    except APIError as error:
        echo_error(error)


if __name__ == '__main__':
    if not abraia.userid:
        # ``configure`` is a click Command: calling it bare makes it parse
        # sys.argv, so arguments meant for ``cli`` (e.g. ``files list``) would
        # abort with a usage error instead of prompting. Pass an empty argv.
        configure(args=[])
    else:
        cli()
