#!/usr/bin/env python 

"""Some utility function for operating on a cluster or MP machine."""

__author__ = "Jens Reeder"
__copyright__ = "Copyright 2010, Jens Reeder, Rob Knight"
__credits__ = ["Jens Reeder", "Rob Knight"]
__license__ = "GPL"
__version__ = "0.91"
__maintainer__ = "Jens Reeder"
__email__ = "jens.reeder@gmail.com"
__status__ = "Release"

from os import remove, system
from string import join, lowercase
from os.path import exists
from time import sleep
from random import sample

from asynchat import async_chat
from socket import socket, AF_INET, SOCK_STREAM, gethostname, error

from cogent.app.util import ApplicationNotFoundError

from Denoiser.settings import *

def submit_jobs(commands, prefix):
    """Submit jobs using CLUSTER_JOBS_SCRIPT.

    commands: list of command strings; they are written one per line to a
              temporary file named prefix + "_commands.txt"
    prefix: string used to build the name of the temporary command file;
            it is also passed as the last argument to CLUSTER_JOBS_SCRIPT

    Raises ApplicationNotFoundError if CLUSTER_JOBS_SCRIPT is empty or does
    not exist on disk.

    The temporary command file is removed afterwards, even if writing it or
    invoking the script raises. The exit status of the script is not checked.
    """

    if not CLUSTER_JOBS_SCRIPT or not exists(CLUSTER_JOBS_SCRIPT):
        raise ApplicationNotFoundError(
            "CLUSTER_JOBS_SCRIPT in setting.py not set!")

    commands_fp = prefix + "_commands.txt"
    fh = open(commands_fp, "w")
    try:
        try:
            fh.write("\n".join(commands))
        finally:
            # always release the handle so the script sees the full file
            fh.close()
        system('%s -ms %s %s' % (CLUSTER_JOBS_SCRIPT, commands_fp, prefix))
    finally:
        # never leave the temporary command file behind
        remove(commands_fp)

def setup_workers(num_cpus, outdir, server_socket, queue=None, verbose=True,
                  error_profile='Data/FLX_error_profile.dat'):
    """Submit one worker job per cpu and wait for each worker to connect.

    num_cpus: number of workers to start
    outdir: path prefix used to build the worker names
    server_socket: listening socket; its host and port are passed to every
                   worker and the incoming connections are accepted on it
    queue: not used by this function
    verbose: if true, "-v" is appended to the worker command
    error_profile: if set, it is passed to the worker with "-e"

    Returns a tuple (workers, client_sockets): the list of worker names and
    the list of (socket, address) pairs, in the order the connections were
    accepted.
    """

    # random 8 letter id shared by all jobs submitted in this call
    job_id = "".join(sample(list(lowercase), 8))
    host, port = server_socket.getsockname()

    # the optional flags are identical for every worker, build them once
    optional_args = ""
    if verbose:
        optional_args += " -v"
    if error_profile:
        optional_args += " -e %s" % error_profile

    worker_names = []
    connections = []
    #TODO: this should be set to a defined wait time using alarm()
    for cpu in range(num_cpus):
        worker_name = outdir + ("/%sworker%d" % (job_id, cpu))
        worker_names.append(worker_name)

        command = "%s %s -f %s -s %s -p %s" % (PYTHON_BIN, DENOISE_WORKER,
                                               worker_name, host, port)
        submit_jobs([command + optional_args], job_id)

        # Block until the worker connects back.
        # This might be a race condition -> make the client robust
        connection, address = server_socket.accept()
        connections.append((connection, address))

    return worker_names, connections

def adjust_workers(num_flows, num_cpus, worker_sockets, log_fh=None):
    """Stop workers no longer needed.

    num_flows: number of flowgrams
    num_cpus: number of CPUs currently used
    worker_sockets: list of connected sockets; the sockets of released
                    workers are popped from the end of this list and closed
    log_fh: open fh to log file

    Returns the number of cpus still in use. Nothing is changed unless
    num_flows is smaller than (num_cpus-1)*MIN_PER_CORE.
    """
    if num_flows < (num_cpus - 1) * MIN_PER_CORE:
        if log_fh:
            log_fh.write("Adjusting number of workers:\n")
            log_fh.write("flows: %d   cpus:%d\n" % (num_flows, num_cpus))
        # Floor division keeps per_core an integer whether or not true
        # division (__future__ division / Python 3) is in effect.
        per_core = max(MIN_PER_CORE, (num_flows // num_cpus) + 1)
        # range() is evaluated once, so the loop still covers the original
        # number of cpus while num_cpus is decremented inside.
        for i in range(num_cpus):
            if i * per_core > num_flows:
                worker_sock = worker_sockets.pop()
                worker_sock.close()
                num_cpus = num_cpus - 1
                if log_fh:
                    log_fh.write("released worker %d\n" % i)

        # sanity check: caller must pass exactly one socket per cpu
        assert(num_cpus == len(worker_sockets))
        if log_fh:
            log_fh.write("New number of cpus:%d\n" % num_cpus)

    return num_cpus

def stop_workers(worker_sockets, log_fh=None):
    """Send the shutdown message to every worker and close its socket.

    worker_sockets:  list of connected sockets
    log_fh: open fh to log file

    A worker whose socket raises a socket error on send is reported to
    log_fh (if given); its socket is closed like all the others.
    """
    shutdown_notice = "Server shutting down all clients"
    dead_worker_note = "Worker %s seems to be dead already. Check for runaways!\n"

    for position, connection in enumerate(worker_sockets):
        try:
            connection.send(shutdown_notice)
        except error:
            # send failed: socket already closed, client dead
            if log_fh:
                log_fh.write(dead_worker_note % position)
        connection.close()

def check_workers(workers, worker_sockets, log_fh=None):
    """Probe every worker socket; stop all workers if one probe fails.

    workers: list of worker names
    worker_sockets: list of connected sockets
    log_fh: open fh to log file

    Returns True if the probe succeeded on every socket. On the first
    socket error the failure is logged (if log_fh is given), stop_workers
    is called on all sockets and False is returned.
    """

    for name, connection in zip(workers, worker_sockets):
        try:
            # dummy send of an empty string, only to see whether it raises
            connection.send("")
            continue
        except error:
            pass

        # only reached when the probe above raised a socket error
        if log_fh:
            log_fh.write("FATAL ERROR\nWorker %s not alive. Aborting\n" % name)
        stop_workers(worker_sockets, log_fh)
        return False

    return True

def setup_server(port=0, verbose=False):
    """Open a port on the server for workers to connect to.

    port: the port number to use, 0 means let OS decide
    verbose: if true, print the address the server is listening on

    Returns the listening socket, bound to the name returned by
    gethostname().

    Raises socket.error if the socket can not be bound; the socket is
    closed before the error is raised.
    """

    host = gethostname()
    sock = socket(AF_INET, SOCK_STREAM)
    try:
        sock.bind((host, port))
    except error as msg:
        # do not leak the file descriptor of the unusable socket
        sock.close()
        raise error("Could not open Socket on server: " + str(msg))
    sock.listen(5) #max num of queued connections usually [1..5]
    if verbose:
        # single parenthesized argument: same output as a print statement
        print("Server listening on %s" % str(sock.getsockname()))
    return sock

def save_send(socket, data):
    """Send data to a socket, retrying until everything has been sent.

    socket:  a connected socket object
    data: string to send over the socket

    Raises socket.error if the send fails with an errno that signals a
    disconnected or invalid socket; retrying such a socket would never
    succeed. Every other socket error is treated as temporary and the send
    is retried after a short pause.
    """
    # local import: errno is needed only by this function
    import errno

    # errno values that mean the peer is gone or the socket is unusable
    # (the same set asyncore treats as a disconnect)
    disconnected = frozenset([errno.ECONNRESET, errno.ENOTCONN,
                              errno.ESHUTDOWN, errno.ECONNABORTED,
                              errno.EPIPE, errno.EBADF])

    #We have no control about how much data the clients accepts
    #we send in chunks until done
    while len(data) > 0:
        try:
            sent_size = socket.send(data)
        except error as exc:
            if getattr(exc, "errno", None) in disconnected:
                # permanent failure: do not spin forever on a dead socket
                raise
            #most likely socket busy, buffer full or not yet ready
            sleep(0.01)
        else:
            #remove sent portion from data
            data = data[sent_size:]

def sendFlowgramToSocket(identifier, flowgram, socket, trim=False):
    """Send one flowgram as a single text line over a socket.

    identifier: identifier of this flowgram

    flowgram: the flowgram itself

    socket: socket to write to

    trim: Boolean flag for quality trimming flowgrams

    The line sent is the identifier, the length of the flowgram and the
    space separated flow values, terminated by a newline.
    """

    if trim:
        flowgram = flowgram.getQualityTrimmedFlowgram()

    # The space separated representation is cached on the flowgram object:
    # re-using it is much quicker than re-generating it for every send.
    if hasattr(flowgram, "spaced_flowgram"):
        flow_text = flowgram.spaced_flowgram
    else:
        flow_text = " ".join([str(value) for value in flowgram.flowgram])
        flowgram.spaced_flowgram = flow_text

    save_send(socket, "%s %d %s\n" % (identifier, len(flowgram), flow_text))
        
class client_handler(async_chat):
    """Collect the data one worker sends back over its socket.

    The handler is created by the main routine with an open socket. It
    buffers everything arriving on that socket until the terminator
    "--END--" is seen, parses the buffered text into rows of floats and
    stores them in the result array passed to the constructor. Afterwards
    the handler removes itself from the asyncore channel map.

    Note: the incoming socket is expected to be connected upon
          initialization and remains connected after the handler is done.
    """

    def __init__(self, sock, worker_number, result_array, log_fh=None):
        """Attach the handler to sock.

        sock: connected socket of the worker
        worker_number: index into result_array where the data is stored
        result_array: container receiving the parsed result
        log_fh: open fh to log file
        """
        async_chat.__init__(self, sock)
        self.results = result_array
        self.number = worker_number
        self.in_buffer = []
        self.set_terminator("--END--")
        if log_fh:
            log_fh.write("Started client handler on %s: %s\n" % self.addr)

    def collect_incoming_data(self, data):
        """Buffer the data; it may arrive in chunks of arbitrary size."""
        self.in_buffer.append(data)

    def found_terminator(self):
        """Parse and store the buffered data once the terminator is found.

        This method is event-triggered.
        """
        # Chunks do not respect line boundaries: concatenate first, then
        # split on newlines and skip empty lines.
        payload = "".join(self.in_buffer)
        rows = []
        for line in payload.split("\n"):
            if line == "":
                continue
            rows.append(map(float, line.split()))
        self.results[self.number] = rows
        self.in_buffer = []

        # Remove this channel from the channel map without closing the
        # socket; once the map is empty the asynchronous loop in the server
        # finishes.
        self.del_channel()

