#
# pickleserver.py
# 
# Copyright (C) Jack Whitham 2008
# Released under the GNU GPL version 2.
#
# About:
#    Lightweight distributed computing platform for Twisted Python.
#    A task described by a dict containing (data_in, data_out) pairs
#    is processed in parallel by as many client machines as you have. The
#    data_in values are requested by each client from pickleserver.py,
#    and once processing is complete, a data_out value is sent back.
#
# Warning:
#    Not suitable for use on untrusted networks! pickle is used for
#    encoding and decoding network messages on the client and server, 
#    and pickle is not secure against code injection attacks. If you 
#    need to use it on the Internet, use a VPN or SSH to secure the
#    network connections.
# 
# How the server works:
# 1. Organise your computing job so that it can be carried out
#    by a single function call:
#      data_out = your_process(data_in)
#    (a) data_in must be an immutable type, e.g. a tuple, a string.
#    (b) Both data_in and data_out must be pickle-able.
#    If you want to turn an arbitrary object into an immutable type
#    you can use pickle to turn it into a string.
# 2. Create a dict() in which every key is one of the objects you
#    want to use as a data_in, and every value is the constant TODO
#    (the string "Job To Do", defined in pickleserver.py below).
# 3. Write the dict() to a disk file using pickle:
#    pickle.dump(your_dict, file("your_file.dat", "wb"))
# 4. Run "python pickleserver.py your_file.dat" to start the server.
# 5. The server periodically updates "your_file.dat" with the latest
#    version of the dict() object.
# 6. The server exits when all tasks are complete, i.e. nothing 
#    is TODO or IN_PROGRESS.
#
# Then, for the client:
# 1. Write a Python program based on the following model:
#
#   from twisted.internet import reactor, defer, protocol
#   from pickleserver import clientRun
#
#   def your_process(data_in):
#       data_out = data_in # or something more complicated...
#       return data_out
#
#   def run():
#       clientRun("localhost", your_process)
#
#   if ( __name__ == "__main__" ):
#       reactor.addSystemEventTrigger('after', 'startup', run)
#       reactor.run()
#
# 2. Change your_process to do whatever you need. 
# 3. Change "localhost" to the host name of the machine 
#    running pickleserver.py.
# 4. Run your Python program on one or more machines.
# 5. clientRun() requests a new job from the server, processes it,
#    and sends the results back. This process repeats as long as
#    the server is online.
# 


from twisted.internet import reactor, defer, protocol
import base64, zlib, pickle, collections, sys, traceback, signal

# TCP port used by both the server and its clients.
PORT_NUMBER = 1981
# Sentinel job states kept as values in the job dict; NO_JOBS is also
# sent to a client when the server has nothing to hand out.
TODO = "Job To Do"               # job waiting to be issued to a client
NO_JOBS = "No Jobs"              # reply meaning "nothing TODO right now"
IN_PROGRESS = "Job In Progress"  # job issued, result not yet received


def receiver(received_queue, data, complete_fn):
    """Accumulate characters from data into received_queue until a
    blank-line terminator ("\\n\\n") is seen, then call complete_fn().

    The terminator check spans calls: the last character already in
    received_queue counts as the preceding character for the first
    character of data. The terminating newlines are appended to the
    queue; any characters after the terminator in this chunk are
    dropped. If no terminator appears, everything is queued and
    complete_fn is not called.
    """
    prev = received_queue[-1] if received_queue else ""
    for ch in data:
        received_queue.append(ch)
        if prev == "\n" and ch == "\n":
            complete_fn()
            return
        prev = ch

def runJob(processor_fn, job):
    """Run processor_fn(job) and return its result, or None on error.

    The SIGCHLD handler (installed by the Twisted reactor) is
    temporarily reset to the default so that processor_fn can spawn
    subprocesses normally; the saved handler is restored afterwards
    even if the job raises.

    NOTE: a job that legitimately returns None is indistinguishable
    from a failed job to the caller.
    """
    # Signals to reset while the job runs; skip any that do not exist
    # on this platform (SIGCHLD is POSIX-only).
    unload_sigs = [s for s in [getattr(signal, "SIGCHLD", None)]
                   if s is not None]
    saved = {}
    for sig in unload_sigs:
        saved[sig] = signal.getsignal(sig)
        signal.signal(sig, signal.SIG_DFL)

    out = None
    try:
        out = processor_fn(job)
    except Exception as e:
        # Report the failure but keep running; the caller sees None.
        print('While running job: %s' % repr(job))
        print('Exception: %s' % repr(e))
        traceback.print_exc(file=sys.stdout)
    finally:
        # Always restore the original handlers.
        for sig in unload_sigs:
            signal.signal(sig, saved[sig])

    return out


class Server(protocol.Protocol):
    """Serves one connected client.

    Protocol: on connect the server immediately sends a job key (or
    NO_JOBS), encoded as base64(zlib(pickle(key))) and terminated by a
    blank line ("\\n\\n"); it then waits for the result in the same
    encoding. Jobs live in the shared db dict, whose values move
    TODO -> IN_PROGRESS -> data_out.
    """

    def __init__(self, db, stopl):
        # db: shared job dict (data_in -> TODO / IN_PROGRESS / data_out).
        # stopl: list holding the Deferred that shuts the server down
        # once every job is complete (popped at most once).
        self.db = db
        self.stopl = stopl
        self.my_job_key = None               # key issued to this client, if any
        self.received = collections.deque()  # raw characters received so far

    def dataReceived(self, data):
        # Data is only expected if a job was issued on this connection.
        if ( self.my_job_key == None ):
            return
        receiver(self.received, data, self.completedJob)

    def completedJob(self):
        # Called by receiver() once the "\n\n" terminator has arrived.
        assert ( self.my_job_key != None )

        rx_code = ''.join(self.received)
        out = None
        try:
            out = pickle.loads(zlib.decompress(base64.b64decode(rx_code)))
        except Exception, e:
            # Could not decode the result: put the job back in the queue.
            print 'Completion: job %s malformed' % repr(self.my_job_key)
            print 'More information: %s' % repr(e)
            self.db[ self.my_job_key ] = TODO

        if ( out != None ):
            # NOTE(review): a result that is literally None would be
            # treated as malformed here; Client.startJob never sends None.
            print 'Completion: job %s done' % repr(self.my_job_key)
            self.db[ self.my_job_key ] = out
            self.checkJobs()

        self.my_job_key = None
        self.transport.loseConnection()

    def checkJobs(self):
        # Count the job states and trigger a clean shutdown when all
        # jobs are done.
        jobs_remaining = False  # NOTE(review): unused
        todo = inp = done = 0
        for (key, data) in self.db.iteritems():
            if ( key != None ):  # skip a None key, if present
                if ( data == TODO ):
                    todo += 1
                elif ( data == IN_PROGRESS ):
                    inp += 1
                else:
                    done += 1

        print 'Status: %u jobs done, %u in progress, %u not started.' % (
                        done, inp, todo)

        if (( todo == 0 ) and ( inp == 0 )):
            print 'All jobs are done.'
            if ( len(self.stopl) != 0 ):
                # Fire the shutdown Deferred exactly once.
                self.stopl.pop().callback(True)

    def connectionLost(self, ignore=None):
        # Client vanished before sending a result: requeue its job.
        if ( self.my_job_key == None ):
            return
        print 'Disconnect: job %s not done' % repr(self.my_job_key)
        self.db[ self.my_job_key ] = TODO
        self.my_job_key = None

    def connectionMade(self):
        # Find a job that needs doing
        assert self.my_job_key == None
        for (key, data) in self.db.iteritems():
            if (( key != None ) and ( data == TODO )):
                self.my_job_key = key
                break

        if ( self.my_job_key == None ):
            print 'Connection: no jobs to be done at this time.'
            send_key = NO_JOBS
            self.checkJobs()
        else:
            print 'Connection: issued job %s' % repr(self.my_job_key)
            send_key = self.my_job_key
            self.db[ self.my_job_key ] = IN_PROGRESS

        # Send base64(zlib(pickle(key))) plus the blank-line terminator
        # that receiver() looks for on the client side.
        send_code = base64.b64encode(zlib.compress(pickle.dumps(send_key), 9))
        self.transport.write(send_code + "\n\n")

        if ( send_key == NO_JOBS ):
            self.transport.loseConnection()

class Client(protocol.Protocol):
    """One instance per connection to the server.

    Receives a job key (or NO_JOBS), runs processor_fn on it
    synchronously, and sends the encoded result back. Fires self.next
    to trigger the next connection, or pops the Deferred from stopl to
    shut this client process down on fatal errors.
    """

    def __init__(self, processor_fn, next, stopl):
        # processor_fn: user function mapping data_in -> data_out.
        # next: Deferred fired to schedule the next server connection.
        # stopl: list holding the Deferred that stops the reactor on
        # fatal errors (popped at most once).
        self.processor_fn = processor_fn
        self.next = next
        self.stopl = stopl
        self.received = collections.deque()  # raw characters received so far
        self.key_received = False            # set once the job key has arrived
        self.job_done = False

    def dataReceived(self, data):
        # Only the initial key message is expected from the server.
        if ( self.key_received ):
            return
        receiver(self.received, data, self.startJob)

    def startJob(self):
        # Called by receiver() once the "\n\n" terminator has arrived.
        assert not self.key_received
        assert not self.job_done
        self.key_received = True

        rx_code = ''.join(self.received)
        job = None
        try:
            job = pickle.loads(zlib.decompress(base64.b64decode(rx_code)))
        except Exception, e:
            # Undecodable key: treat as fatal and stop this client.
            print 'New job: key is malformed!'
            print 'More information: %s' % repr(e)
            traceback.print_exc(file=sys.stdout)
            self.transport.loseConnection()
            if ( len(self.stopl) != 0 ):
                self.stopl.pop().callback(True)
            return

        if ( job == NO_JOBS ):
            # Server had nothing to hand out: poll again in a minute.
            print 'No jobs are available at this time.'
            print 'Try again in 1 minute...'
            self.transport.loseConnection()

            def tryAgain():
                self.next.callback(True)

            reactor.callLater(60.0, tryAgain)
            return

        print 'New job: %s' % repr(job)
        # Runs synchronously; the reactor is blocked while the job runs.
        out = runJob(self.processor_fn, job)
        if ( out == None ):
            # processor_fn raised (or returned None): stop the client.
            self.transport.loseConnection()
            if ( len(self.stopl) != 0 ):
                self.stopl.pop().callback(True)
            return

        self.job_done = True
        print 'Job completed, sending results.'
        # Same wire encoding as the server: base64(zlib(pickle(out))).
        send_code = base64.b64encode(zlib.compress(pickle.dumps(out), 9))
        self.transport.write(send_code + "\n\n")
        self.transport.loseConnection()
        self.next.callback(True)

    def connectionLost(self, ignore=None):
        # A disconnect before the key arrived means the server is gone
        # (e.g. all jobs complete); stop this client.
        if ( not self.key_received ):
            print 'Disconnect: server down?'
            if ( len(self.stopl) != 0 ):
                self.stopl.pop().callback(True)

    def connectionMade(self):
        print 'Connected, waiting for work.'

class ServerFactory(protocol.ServerFactory):
    """Creates a Server protocol for every incoming connection,
    forwarding the constructor arguments given to the factory.
    """
    protocol = None

    def __init__(self, *params):
        self.params = params

    def buildProtocol(self, addr):
        return Server(*self.params)

class ClientFactory(protocol.ClientFactory):
    """Creates a Client protocol for each outgoing connection, and
    aborts the run (by firing the stop Deferred) when the server
    cannot be reached at all.
    """
    protocol = None

    def __init__(self, *params):
        self.params = params
        # Same (processor_fn, next, stopl) triple that Client expects.
        processor_fn, job_trigger, self.stopl = params

    def buildProtocol(self, addr):
        return Client(*self.params)

    def clientConnectionFailed(self, connector, reason):
        print('Client connection failed: %s' % repr(reason))
        if self.stopl:
            self.stopl.pop().callback(True)

def serverArgvRun():
    """Command-line entry point for the server.

    Expects exactly one argument: the name of a file containing a
    pickled, non-empty dict of jobs. On success, hands over to
    serverRun(); on any failure, stops the reactor.
    """
    db = None
    db_file = None
    try:
        if len(sys.argv) != 2:
            print('Usage: %s <pickle database>' % sys.argv[0])
        else:
            db_file = sys.argv[1]
            print('Database: %s' % db_file)
            # with-open replaces the Python-2-only file() builtin and
            # guarantees the handle is closed (it was leaked before).
            with open(db_file, "rb") as fd:
                db = pickle.load(fd)
            # Validate explicitly: assert is stripped under "python -O".
            if not isinstance(db, dict):
                raise TypeError("database must be a dict")
            if not db:
                raise ValueError("database is empty")

    except Exception as e:
        print('Exception: %s' % repr(e))
        db = None

    if db is None:
        reactor.stop()
        return

    serverRun(db_file, db)

def serverRun(db_file, db):
    """Serve the jobs in db on PORT_NUMBER, checkpointing to db_file.

    Jobs left IN_PROGRESS by a previous run are reset to TODO. The
    database file is rewritten every 60 seconds, and once more during
    the clean shutdown triggered via the stop Deferred (fired by
    Server.checkJobs when nothing is TODO or IN_PROGRESS).
    """
    # Requeue jobs that were in flight when a previous server stopped.
    # Snapshot the items so the dict is not mutated mid-iteration.
    for key, data in list(db.items()):
        if data == IN_PROGRESS:
            db[key] = TODO

    def writeDatabase():
        # Checkpoint the db; skipped once it has been cleared on exit.
        if db:
            # with-open replaces the Python-2-only file() builtin and
            # makes sure the handle is flushed and closed promptly.
            with open(db_file, "wb") as fd:
                pickle.dump(db, fd)

    def periodic():
        # Reschedule first, so a dump failure cannot end the timer chain.
        reactor.callLater(60, periodic)
        writeDatabase()

    def fnStop(k=None):
        writeDatabase()
        print('Clean exit')
        db.clear()
        reactor.stop()

    print('Service port: %u' % PORT_NUMBER)
    stop = defer.Deferred()
    sf = ServerFactory(db, [stop])
    reactor.listenTCP(PORT_NUMBER, sf)
    reactor.callLater(60, periodic)
    stop.addCallback(fnStop)

def clientRun(hostname, processor_fn):
    """Repeatedly fetch jobs from the server at hostname and run
    processor_fn on each, sending the results back, until the server
    goes away (at which point the reactor is stopped).
    """
    stop = defer.Deferred()
    stop.addCallback(lambda _: reactor.stop())

    def launch(_=None):
        # Each completed (or retried) job fires 'again', which in turn
        # schedules the next connection to the server.
        again = defer.Deferred()
        again.addCallback(launch)
        reactor.connectTCP(hostname, PORT_NUMBER,
                ClientFactory(processor_fn, again, [stop]))

    launch()

if ( __name__ == "__main__" ):
    # Start serverArgvRun only after the reactor is up, so that it can
    # call reactor.stop() safely if the command line or database is bad.
    reactor.addSystemEventTrigger('after', 'startup', serverArgvRun)
    reactor.run()



