import gzip
import json
import shutil
from pathlib import Path
import logging

from Bio import SeqIO

# Give the root logger a default handler and lower its threshold to INFO so
# that pipeline progress messages (e.g. from filter_transcripts) are shown.
logging.basicConfig()
logging.root.setLevel(logging.INFO)

from yaml import load
from cortexpy.command import get_exit_code_yaml_path
CORTEXPY_EXIT_CODES = load(open(get_exit_code_yaml_path(), 'rt'))

# Pipeline parameters, read from args.json in the working directory.
with open('args.json') as fh:
    ARGS = json.load(fh)

# Read files keyed by read end: FASTX_INPUT feeds graph construction and
# KALLISTO_INPUT feeds kallisto quantification. Entries may be None.
FASTX_INPUT = {end: ARGS[f'fastx_{end}'] for end in ('forward', 'reverse', 'single')}
KALLISTO_INPUT = {end: ARGS[f'kallisto_fastx_{end}'] for end in ('forward', 'reverse', 'single')}

# McCortex command prefix (including the k-mer size) and the flags shared by
# every McCortex sub-command invoked below.
MCCORTEX = f'mccortex {ARGS["kmer_size"]}'
MCCORTEX_ARGS = ' '.join(
    ['--sort', '--force', '-m', f'{ARGS["memory"]}G'] + (['--quiet'] if ARGS['quiet'] else [])
)

# Target rule: request the final transcript FASTA of every subgraph. The subgraph
# ids are not known in advance (traverse_cortex_subgraph declares a dynamic()
# output), so the targets are declared with dynamic() as well.
# NOTE(review): dynamic() is deprecated in newer Snakemake releases in favour of
# checkpoints -- confirm the Snakemake version this workflow is pinned to.
rule all:
    input:
        dynamic('transcripts/g{sg_id}.transcripts.fa.gz')


rule full_cortex_graph:
    input: [v for v in FASTX_INPUT.values() if v is not None]
    output: 'cortex_graph/full.ctx'
    threads: 16
    run:
        cmd = [MCCORTEX, 'build', MCCORTEX_ARGS, '--threads', threads, '--kmer', ARGS['kmer_size'],
               '--sample', 'abeona']
        if FASTX_INPUT['forward'] is not None:
            cmd += ['--seq2', f'{FASTX_INPUT["forward"]}:{FASTX_INPUT["reverse"]}']
        if FASTX_INPUT['single'] is not None:
            cmd += ['--seq', FASTX_INPUT['single']]
        cmd.append(output)
        cmd = [str(c) for c in cmd]
        shell(' '.join(cmd))

rule clean_cortex_graph:
    input: 'cortex_graph/full.ctx'
    output: 'cortex_graph/full.clean.ctx'
    threads: 16
    run:
        shell(f'{MCCORTEX} clean {MCCORTEX_ARGS}'
              f' --threads {threads}'
              f' -T0 -U{ARGS["min_unitig_coverage"]}'
              f' --out {output} {input}')

rule prune_cortex_graph_of_tips:
    input: 'cortex_graph/full.clean.ctx'
    output: f'cortex_graph/full.clean.min_tip_length_{ARGS["min_tip_length"]}.ctx'
    run:
        if ARGS["min_tip_length"] is None:
            shell('cp {input} {output}')
        else:
            shell(f'cortexpy prune --out {output} {input} --remove-tips {ARGS["min_tip_length"]}')

rule traverse_cortex_subgraph:
    input: f'cortex_graph/full.clean.min_tip_length_{ARGS["min_tip_length"]}.ctx'
    output: dynamic('traversals/g{sg_id}.traverse.ctx')
    threads: 16
    run:
        out_dir = Path(output[0]).parent
        cmd = f'python -m abeona subgraphs {input} {out_dir} -m {ARGS["memory"]} -c {threads}'
        if ARGS['initial_contigs'] is not None:
            cmd += f' --initial-contigs {ARGS["initial_contigs"]}'
        shell(cmd)

rule candidate_transcripts:
    input: 'traversals/g{sg_id}.traverse.ctx'
    output: 'candidate_transcripts/g{sg_id}.transcripts.fa.gz'
    run:
        output_tmp = str(output) + '.tmp'
        out_file = output[0]
        shell(f"""
        set +e
        cortexpy view traversal {input} --max-paths {ARGS["max_paths_per_subgraph"]} \
               | gzip -c > {output_tmp}
        exitcode=$?
        set -e
        if [ $exitcode -eq 0 ]; then
            mv -f {output_tmp} {out_file}
            touch {out_file}.ok
        elif [ $exitcode -eq {CORTEXPY_EXIT_CODES["MAX_PATH_EXCEEDED"]} ]; then
            mv -f {output_tmp} {out_file}
        else
            exit $exitcode
        fi
        """)

rule create_transcripts_for_subgraph:
    input:
        candidate_transcripts='candidate_transcripts/g{sg_id}.transcripts.fa.gz',
        fastxs = [v for v in KALLISTO_INPUT.values() if v is not None]
    output: 'transcripts/g{sg_id}.transcripts.fa.gz'
    run:
        # If the .ok flag is not set, then this set of candidate transcripts is incomplete
        # and we shall ignore it
        if not Path(input.candidate_transcripts + '.ok').exists():
            shell('touch {output}')
            return
        kallisto_index = build_kallisto_index(
            input=input[0],
            output=f'kallisto_indices/g{wildcards.sg_id}.transcripts.ki'
        )

        out_dir = kallisto_quant(
            kallisto_index=kallisto_index,
            out_dir=f'kallisto_quant/g{wildcards.sg_id}'
        )

        filter_transcripts(kallisto_quant_dir=out_dir, output=output[0], wildcards=wildcards)


def build_kallisto_index(input, output):
    """Build a kallisto index at *output* from the FASTA file *input*.

    ``--kmer-size`` is passed only when ``ARGS['kmer_size']`` is below 31;
    otherwise kallisto's own default k-mer size applies.

    :param input: path to the transcript FASTA to index.
    :param output: path of the index file to create; its parent directory is
        created if missing.
    :return: *output*, unchanged.
    """
    Path(output).parent.mkdir(exist_ok=True)
    kmer_option = f' --kmer-size {ARGS["kmer_size"]}' if int(ARGS['kmer_size']) < 31 else ''
    shell(f'kallisto index -i {output} {input}{kmer_option}')
    return output


def kallisto_quant(kallisto_index, out_dir):
    """Run ``kallisto quant`` with bootstrapping and plain-text output.

    Paired-end reads are used when ``ARGS['kallisto_fastx_forward']`` is set.
    Otherwise single-end mode is used, with the fragment length and standard
    deviation taken from ``ARGS``.

    :param kallisto_index: path to the kallisto index.
    :param out_dir: output directory (created, with parents, if missing).
    :return: *out_dir*, unchanged.
    """
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    pieces = [
        f'kallisto quant -i {kallisto_index} --output-dir {out_dir}',
        f' -b {ARGS["bootstrap_samples"]} --plaintext',
    ]
    if ARGS['kallisto_fastx_forward'] is None:
        pieces.append(f' -l {ARGS["kallisto_fragment_length"]} -s {ARGS["kallisto_sd"]}')
        pieces.append(f' --single {ARGS["kallisto_fastx_single"]}')
    else:
        pieces.append(f' {ARGS["kallisto_fastx_forward"]} {ARGS["kallisto_fastx_reverse"]}')
    shell(''.join(pieces))
    return out_dir


def filter_transcripts(kallisto_quant_dir, output, wildcards, candidate_transcripts=None):
    """Keep candidate transcripts robustly supported by kallisto and write them to *output*.

    A transcript is kept when at least ``keep_prop`` (0.95) of the kallisto
    bootstrap samples give it an estimated count >= ``est_count_threshold`` (1).
    Kept records are annotated with that proportion and written as gzipped FASTA.

    :param kallisto_quant_dir: ``kallisto quant --plaintext`` output directory
        holding ``bs_abundance_<i>.tsv`` for i in ``range(ARGS['bootstrap_samples'])``.
    :param output: path of the gzipped FASTA to write; parent dirs are created.
    :param wildcards: Snakemake wildcards; ``wildcards.sg_id`` locates the
        candidate FASTA when *candidate_transcripts* is not given.
    :param candidate_transcripts: optional explicit path to the gzipped candidate
        FASTA; defaults to ``candidate_transcripts/g{sg_id}.transcripts.fa.gz``.
    :raises ValueError: if ``ARGS['bootstrap_samples']`` is less than 1.
    """
    import pandas as pd

    logger = logging.getLogger('abeona.assembly.filter_transcripts')
    Path(output).parent.mkdir(parents=True, exist_ok=True)

    # Coerce like build_kallisto_index does for kmer_size, so a string value from
    # args.json still works with range().
    n_bootstraps = int(ARGS['bootstrap_samples'])
    if n_bootstraps < 1:
        raise ValueError(
            f'filter_transcripts needs at least one kallisto bootstrap sample, got {n_bootstraps}')

    bootstraps = pd.concat(
        [pd.read_csv(f'{kallisto_quant_dir}/bs_abundance_{i}.tsv', sep='\t',
                     dtype={'target_id': str, 'length': int})
         for i in range(n_bootstraps)],
        ignore_index=True,
    )

    est_count_threshold = 1
    keep_prop = 0.95
    # Per transcript: fraction of bootstrap samples whose est_counts reach the threshold.
    ge1_counts = bootstraps.groupby('target_id')['est_counts'].aggregate(
        lambda counts: (counts >= est_count_threshold).mean())
    keep_counts = ge1_counts[ge1_counts >= keep_prop]

    logger.info(
        f'Keeping contigs with >= {keep_prop} of bootstrapped est_counts >= {est_count_threshold}')
    logger.info(f'keeping {len(keep_counts)} out of {len(ge1_counts)} contigs.')

    if candidate_transcripts is None:
        candidate_transcripts = f'candidate_transcripts/g{wildcards.sg_id}.transcripts.fa.gz'
    filtered_records = filter_and_annotate_contigs(keep_counts, candidate_transcripts)
    logger.info(f'Writing filtered records to {output}')
    with gzip.open(str(output), 'wt') as fh:
        SeqIO.write(filtered_records, fh, "fasta")


def filter_and_annotate_contigs(filtered_counts, candidate_transcripts):
    """Lazily yield FASTA records whose id appears in *filtered_counts*.

    Each yielded record's description is replaced with
    ``prop_bs_est_counts_ge_1=<value>``, where the value is looked up in
    *filtered_counts* by record id.

    :param filtered_counts: Series-like object indexed by record id
        (supports ``.index`` membership and ``.at[...]``).
    :param candidate_transcripts: path to a gzip-compressed FASTA file.
    """
    with gzip.open(str(candidate_transcripts), 'rt') as fasta_fh:
        for rec in SeqIO.parse(fasta_fh, "fasta"):
            if rec.id not in filtered_counts.index:
                continue
            rec.description = 'prop_bs_est_counts_ge_1={}'.format(filtered_counts.at[rec.id])
            yield rec

