#!/usr/bin/python3

from mpi4py import MPI
from pyfaidx import Fasta
from termcolor import colored

from A2G.__version__ import version
from A2G.align2consensus import *


def split_n_yield(fasta, chunks):
    """Split a FASTA file into ``chunks`` FASTA-formatted text chunks.

    Records are kept in file order and distributed the same way
    ``numpy.array_split`` distributes elements: the first
    ``len(records) % chunks`` chunks get one extra record. A chunk that
    receives no records is yielded as an empty string, so exactly ``chunks``
    strings are always produced (``comm.scatter`` needs one item per rank).

    :param fasta: path to the FASTA file, read with ``pyfaidx.Fasta``.
    :param chunks: number of chunks to produce; must be at least 1.
    :return: generator of strings; each holds its records as a ``>`` header
        line followed by the sequence line, records joined by newlines.
    :raises ValueError: if ``chunks`` is smaller than 1.
    """
    if chunks < 1:
        raise ValueError('chunks must be >= 1, got %r' % (chunks,))
    records = list(Fasta(fasta).values())
    # Split by index instead of handing FastaRecord objects to numpy, which
    # may try to coerce the (sequence-like) records into an array.
    base, extra = divmod(len(records), chunks)
    start = 0
    for index in range(chunks):
        stop = start + base + (1 if index < extra else 0)
        # A real newline must separate the header from the sequence for the
        # chunk to be valid FASTA.
        yield '\n'.join('>%s\n%s' % (record.long_name, record)
                        for record in records[start:stop])
        start = stop


parser = argparse.ArgumentParser()
parser.add_argument('global_consensus', help='Sequence consensus of the '
                                             'global region, e.g. full COI')
parser.add_argument('local_consensus',
                    help='Sequence consensus of the local region, e.g. '
                         'Leray fragment')
parser.add_argument('fasta', help='fasta file with the focal sequences')
parser.add_argument('--out_prefix', action='store', default='A2G_aln',
                    help='Prefix of outputs')
# With store_false, passing the flag sets the value to False (default True).
# NOTE(review): this value is parsed but not currently forwarded to Align.
parser.add_argument('--remove_duplicates', action='store_false',
                    help='Keep or remove duplicated sequences',
                    default=True)

args = parser.parse_args()

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

if rank == 0:
    print('\nA2G_mpi version:', colored(version, None, attrs=["bold", "blink"]))
    print(colored('Copyright 2020', 'red', attrs=["bold"]), 'Jose Sergio Hleap\n')
    print("-" * 78)
    print(" Running on %d cores" % size)
    print("-" * 78)

# Every rank builds its own aligner with a single CPU; presumably the
# parallelism comes from the MPI ranks rather than from Align itself.
aln = Align(gene_consensus=args.global_consensus, cpus=1,
            amplicon_consensus=args.local_consensus, no_write=True,
            out_prefix=args.out_prefix)

# Only the root reads the input FASTA and splits it into one chunk per rank.
if rank == 0:
    data = list(split_n_yield(args.fasta, size))
else:
    data = None

# Each rank receives its FASTA-text chunk and aligns it locally.
aln.query = comm.scatter(data, root=0)
results = aln.run()

# Gather results on rank 0 with a single collective call; the other ranks
# receive None.
results = comm.gather(results, root=0)

if rank == 0:
    # Flatten the per-rank lists into one list. Assumes each aln.run() result
    # is an iterable of (aligned_fasta, with_outliers) string pairs.
    pairs = [item for rank_results in results for item in rank_results]
    # zip(*[]) cannot be unpacked into two names, so handle "no results".
    fasta, subset = zip(*pairs) if pairs else ((), ())
    with open('%s_aligned.fasta' % args.out_prefix, 'w') as o, open(
            '%s_aligned.withoutliers' % args.out_prefix, 'w') as w:
        o.write('\n'.join(fasta))
        w.write('\n'.join(subset))