# ssh daemon, or something like that
# Wijnand 'tehmaze' Modderman - http://tehmaze.com
# BSD License
#
# TODO:
#  * tab completion
#  * history search?
#  * proper relaying, probably by implementing a `sshbot` class
#
# BUGS/LIMITATIONS:
#  * relay only relays to jabber/irc and not back, due to framework limitations
#  * only default VT100/ANSI capable terminals are supported
#

from gozerbot.commands import cmnds
from gozerbot.config import config
from gozerbot.datadir import datadir
from gozerbot.fleet import fleet
from gozerbot.generic import handle_exception, geturl2, rlog, waitforqueue
from gozerbot.ircevent import Ircevent
from gozerbot.monitor import outmonitor
from gozerbot.pdol import Pdol
from gozerbot.persistconfig import PersistConfig
from gozerbot.plugins import plugins
from gozerbot.users import users
import gozerbot.thr as thr
import base64
from binascii import hexlify
import os
import Queue
import socket
import threading
import time
import paramiko as ssh

# Persistent plugin settings, declared together with their default values.
cfg = PersistConfig()
for _option, _default in (
        ('port', 2200),                 # port the daemon binds to
        ('host', '127.0.0.1'),          # address the daemon binds to
        ('keyfile', 'gozerbot.rsa.key'),  # server RSA key file
        ('motd', 'gozerbot.motd'),      # text shown after login, if present
        ):
    cfg.define(_option, _default)
del _option, _default

# The running Sshd instance (set by init()).
sshd = None

def init():
    global sshd
    try:
        sshd = Sshd()
        time.sleep(5)
        sshd.start()
    except:
        handle_exception()
        return 0
    return 1

def shutdown():
    global sshd
    try:
        sshd.stop()
    except:
        handle_exception()
        return 0
    return 1

class SshIdentity(Pdol):
    '''Registry of the public keys each user may authenticate with.

    Keys are kept as lists per user name in a Pdol created with the path
    datadir/ssh-identities; save() is called after every addition.
    '''
    def __init__(self):
        Pdol.__init__(self, os.path.join(datadir, 'ssh-identities'))

    def add(self, user, key):
        '''Register key for user (duplicates are ignored), then save.'''
        if not self.data.has_key(user):
            self.data[user] = [key]
        elif key not in self.data[user]:
            self.data[user].append(key)
        self.save()

    def add_rsa(self, user, b64key):
        '''Decode a base64 ssh-rsa key blob into an RSAKey and register it.'''
        self.add(user, ssh.RSAKey(data=base64.decodestring(b64key)))

    def check(self, user, key):
        '''Return True when key is registered for user, else False.'''
        return self.data.has_key(user) and key in self.data[user]

class SshClient(object):
    '''Per-connection state: input line buffer, cursor, history, and the
    bot/channel ("watch") this ssh session talks to.'''

    bot     = None
    buffer  = ''
    cursor  = 0
    index   = 0
    monitor = False
    watch   = 'ssh'

    def __init__(self):
        # history is grown in place with append(), so it must be created per
        # instance; a class-level list would be shared by every client
        self.history = []
        if config['jabberenable']:
            self.bot = fleet.byname('jabbermain')
        else:
            self.bot = fleet.byname('main')

    def prompt(self, text='', **options):
        '''Redraw the prompt followed by text and position the terminal cursor.

        Options:
          reset  -- when true (the default) go to column 0 and erase the line first
          cursor -- position within text to leave the cursor at; when omitted
                    (None) self.cursor is used. 0 is a valid explicit position.
        '''
        defaults = {'reset': True, 'cursor': None}
        defaults.update(options)
        if defaults['reset']:
            # carriage return + ANSI "erase to end of line"
            self.channel.send('\r\x1b[K')
        prompt = '%s@gozerbot [%s/%s]> ' % (self.transport.get_username(),
            self.bot.name, self.watch)
        self.channel.send(prompt)
        if defaults['cursor'] is not None:
            back = len(text) - defaults['cursor']
        else:
            back = len(text) - self.cursor
        self.channel.send(text)
        if back > 0:
            # ANSI cursor-back: move left from the end of text to the cursor
            self.channel.send('\x1b[%dD' % back)

    def reply(self, text):
        '''Send text to the client; raises ssh.SSHException if no channel is set.'''
        if hasattr(self, 'channel'):
            self.channel.send(text)
        else:
            raise ssh.SSHException('No channel to send text to.')

class SshKey(ssh.RSAKey):
    '''(Server) RSA key, loaded from filename.

    When filename does not exist yet, a new RSA key of the given size is
    generated and written there first. This means the daemon always has a
    usable host key instead of an uninitialised key object.
    '''
    def __init__(self, filename, bits=2048):
        if not os.path.isfile(filename):
            rlog(5, 'sshd', 'host key %s not found, generating a new %d bit RSA key' % (filename, bits))
            ssh.RSAKey.generate(bits).write_private_key_file(filename)
        ssh.RSAKey.__init__(self, filename=filename)

# Host key the daemon presents to connecting clients.
server_key = SshKey(filename=cfg.get('keyfile'))
# Public keys each bot user may log in with.
user_keys = SshIdentity()

class SshServer(ssh.ServerInterface):
    '''The ssh server policy: accepts session channels, public-key
    authentication against user_keys, pty requests and shell requests.'''
    def __init__(self):
        # set when the client requests a shell; handle_client waits on it
        self.event = threading.Event()

    def check_channel_request(self, kind, chanid):
        if kind == 'session':
            return ssh.OPEN_SUCCEEDED
        return ssh.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED

    def check_auth_publickey(self, username, key):
        if user_keys.check(username, key):
            return ssh.AUTH_SUCCESSFUL
        return ssh.AUTH_FAILED

    def get_allowed_auths(self, username):
        # only public keys are checked here (no check_auth_password override),
        # so advertising "password" would make clients prompt for a password
        # that can never be accepted
        return 'publickey'

    def check_channel_shell_request(self, channel):
        self.event.set()
        return True

    def check_channel_pty_request(self, channel, term, width, height, pixelwidth,
        pixelheight, modes):
        return True

class Sshd:
    '''The ssh dispatcher and connection handler.

    start() binds a listening socket and spawns loop(), which accepts
    connections and runs handle_client() for each one on its own thread.
    '''

    def __init__(self):
        self.run = False

    def start(self):
        '''Bind the listening socket and start the accept loop thread.

        Binding is retried while the address is in use; any other bind error,
        or running out of retries, closes the socket and re-raises socket.error.
        '''
        import errno
        self.run = True
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except (AttributeError, socket.error):
            rlog(5, 'sshd', 'warning, SO_REUSEADDR not available on this platform')
        try:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        except (AttributeError, socket.error):
            rlog(5, 'sshd', 'warning, SO_REUSEPORT not available on this platform')
        self.socket.setblocking(1)
        # short accept() timeout so loop() regularly re-checks self.run
        self.socket.settimeout(1)
        retries = 5
        while True:
            try:
                self.socket.bind((cfg.get('host'), cfg.get('port')))
                break
            except socket.error, e:
                # only "address already in use" may clear up by waiting
                if e[0] != errno.EADDRINUSE or retries == 0:
                    self.run = False
                    self.socket.close()
                    raise
                retries -= 1
                rlog(5, 'sshd', 'Address already in use, %d retries left' % retries)
                time.sleep(2)
        self.socket.listen(100)
        thr.start_new_thread(self.loop, ())

    def stop(self):
        '''Stop accepting connections and close the listening socket.'''
        self.run = False
        sock = getattr(self, 'socket', None)
        if sock is not None:
            sock.close()

    def loop(self):
        '''Accept connections until stop() is called; one handler thread per client.'''
        while self.run:
            try:
                client, addr = self.socket.accept()
            except socket.timeout:
                continue
            except socket.error:
                if not self.run:
                    break   # stop() closed the socket underneath accept()
                handle_exception()
                time.sleep(1)   # don't spin if accept() keeps failing
                continue
            thr.start_new_thread(self.handle_client, (client, addr))

    def handle_client(self, client, addr):
        '''Negotiate ssh on an accepted socket and run the interactive session.

        Failures before the shell starts drop the connection via drop_client().
        '''
        rlog(5, 'sshd', 'connection from %s' % str(addr))
        sshclient = SshClient()
        sshclient.addr = addr
        sshclient.client = client
        sshclient.transport = ssh.Transport(client)
        try:
            sshclient.transport.load_server_moduli()
        except Exception:
            rlog(5, 'sshd', '%s: failed to load moduli -- no gex support' % str(addr))
        sshclient.transport.add_server_key(server_key)
        sshclient.server = SshServer()
        try:
            sshclient.transport.start_server(server=sshclient.server)
        except ssh.SSHException:
            rlog(5, 'sshd', '%s: negotiation failed' % str(sshclient.addr))
            self.drop_client(sshclient)
            return
        sshclient.channel = sshclient.transport.accept(50)
        if sshclient.channel is None:
            rlog(5, 'sshd', '%s: no channel requested' % str(sshclient.addr))
            self.drop_client(sshclient)
            return
        rlog(5, 'sshd[%s]' % sshclient.channel.get_id(), 'authenticated connection from %s:%d' % tuple(sshclient.addr))
        # feed username@ip to the users object
        sshclient.userhost = '%s@%s' % (sshclient.transport.get_username(), sshclient.addr[0])
        if not users.getname(sshclient.userhost):
            users.adduserhost(sshclient.transport.get_username(), sshclient.userhost)
        sshclient.server.event.wait(10)
        if not sshclient.server.event.isSet():
            rlog(5, 'sshd[%s]' % sshclient.channel.get_id(), 'no shell requested')
            self.drop_client(sshclient)
            return
        rlog(5, 'sshd[%s]' % sshclient.channel.get_id(), 'starting session with %s:%d' % tuple(sshclient.addr))
        sshclient.reply('\r\nWelcome to %s\r\n' % config['version'])
        motd = cfg.get('motd')
        if os.path.isfile(motd):
            motdfile = open(motd, 'r')
            try:
                sshclient.reply(motdfile.read().replace('\n', '\r\n'))
            finally:
                motdfile.close()
        sshclient.reply('\r\nThis terminal has very limited VT100/ANSI capabilities\r\n')
        sshclient.reply('Use /command to control the terminal (see /help)\r\n')
        sshclient.reply('Use !command to dispatch a command to the bot\r\n')
        self.loop_client(sshclient)

    def loop_client(self, sshclient):
        '''Run the line editor for one session, reading the channel byte by byte.

        The session ends on ^D with an empty line, after /quit, when the
        channel reaches end of stream, or when the daemon stops. In every
        case the connection is torn down through drop_client().
        '''
        escape = False
        sequence = ''
        try:
            # fake join event
            self.fake_join(sshclient)

            # send prompt to client
            sshclient.prompt()
            while self.run:
                c = sshclient.channel.recv(1)
                if len(c) == 0:
                    # an empty read means the channel stream is closed; looping
                    # again would spin forever on a dead connection
                    rlog(5, 'sshd', '%s: connection closed by peer' % str(sshclient.addr))
                    break
                if c == KEYMAP['EOT']: # ^D
                    if len(sshclient.buffer) > 0:
                        sshclient.buffer = ''
                        sshclient.cursor = 0
                        sshclient.prompt()
                    else:
                        sshclient.reply('\r\nGoodbye\r\n')
                        break
                elif c == KEYMAP['LF']: # ^J
                    pass # ignore
                elif c == KEYMAP['FF']: # ^L (Redraw screen)
                    sshclient.prompt(sshclient.buffer)
                elif c == KEYMAP['CR']: # ^M (Return)
                    sshclient.channel.send('\r\n')
                    self._execute_line(sshclient)
                    if getattr(sshclient, 'closing', False):
                        break   # set by /quit
                    sshclient.buffer = ''
                    sshclient.cursor = 0
                    sshclient.prompt()
                elif c == KEYMAP['ETB']: # ^W (Delete word)
                    if ' ' in sshclient.buffer:
                        sshclient.buffer = ' '.join(sshclient.buffer.split(' ')[:-1])
                        sshclient.cursor = len(sshclient.buffer)
                    else:
                        sshclient.buffer = ''
                        sshclient.cursor = 0
                    sshclient.prompt(sshclient.buffer)
                elif c in (KEYMAP['BS'], '\x7f'): # Backspace (^H or DEL, terminal dependent)
                    # nothing to delete at the start of the line; slicing with
                    # cursor-1 == -1 would duplicate the buffer
                    if sshclient.cursor > 0:
                        pos = sshclient.cursor
                        sshclient.buffer = sshclient.buffer[:pos-1] + sshclient.buffer[pos:]
                        sshclient.cursor = pos - 1
                        sshclient.prompt(sshclient.buffer)
                elif c == KEYMAP['ESC']: # Escape
                    escape = True
                    sequence = ''
                elif escape:
                    escape, sequence = self._handle_escape(sshclient, sequence, c)
                else:
                    pos = sshclient.cursor
                    if pos == len(sshclient.buffer):
                        # appending at the end: just echo the byte
                        sshclient.channel.send(c)
                        sshclient.buffer += c
                        sshclient.cursor = pos + 1
                    else:
                        sshclient.buffer = sshclient.buffer[:pos] + c + sshclient.buffer[pos:]
                        sshclient.cursor = pos + 1
                        sshclient.prompt(sshclient.buffer)
        finally:
            self.drop_client(sshclient)

    def _execute_line(self, sshclient):
        '''Record the current input line in history and run it (/cmd, !cmd or chat).'''
        line = sshclient.buffer
        sshclient.history.append(line)
        sshclient.index = len(sshclient.history)-1
        try:
            if line.startswith('/'):
                result = self.handle_sshd_command(sshclient, line[1:])
            elif line.startswith('!'):
                result = self.handle_command(sshclient, line[1:])
            else:
                self.handle_chat(sshclient, line)
                result = []
            # tolerate handlers that return None instead of a list
            for text in result or []:
                sshclient.reply('%s\r\n' % text)
        except Exception, e:
            handle_exception()
            sshclient.reply('Exception: %s\r\n' % str(e))

    def _handle_escape(self, sshclient, sequence, c):
        '''Feed byte c of an escape sequence; return (still_in_escape, sequence).

        A completed SEQMAP sequence triggers its history/cursor action, if it
        has one. Completed sequences without an action, and bytes that cannot
        continue any known sequence, end escape mode.
        '''
        pos = len(sequence)
        # the length check keeps completed sequences from being indexed past
        # their end
        expected = [v[pos] for v in SEQMAP.values()
                    if len(v) > pos and v.startswith(sequence)]
        if c not in expected:
            return False, ''
        sequence += c
        if sequence not in SEQMAP.values():
            return True, sequence   # prefix of a longer sequence, keep reading
        if sequence == SEQMAP['up']:
            if sshclient.history:
                sshclient.buffer = sshclient.history[sshclient.index]
                sshclient.cursor = len(sshclient.buffer)
                sshclient.prompt(sshclient.buffer)
                if sshclient.index > 0:
                    sshclient.index -= 1
        elif sequence == SEQMAP['down']: # cursor down
            if sshclient.history:
                if sshclient.index < (len(sshclient.history)-1):
                    sshclient.index += 1
                sshclient.buffer = sshclient.history[sshclient.index]
                sshclient.cursor = len(sshclient.buffer)
                sshclient.prompt(sshclient.buffer)
        elif sequence == SEQMAP['right']: # cursor forward
            if sshclient.cursor < len(sshclient.buffer):
                sshclient.cursor += 1
            sshclient.prompt(sshclient.buffer)
        elif sequence == SEQMAP['left']: # cursor backward
            if sshclient.cursor > 0:
                sshclient.cursor -= 1
            sshclient.prompt(sshclient.buffer)
        elif sequence == SEQMAP['home']:
            sshclient.cursor = 0
            sshclient.prompt(sshclient.buffer)
        elif sequence == SEQMAP['end']:
            sshclient.cursor = len(sshclient.buffer)
            sshclient.prompt(sshclient.buffer)
        return False, ''

    def cook_ievent(self, sshclient):
        '''Build an Ircevent carrying this session's nick, userhost, channel and
        bot, with a fresh reply queue appended to ievent.queues.'''
        ievent = Ircevent()
        # callers that need text set ievent.txt themselves
        ievent.txt = ''
        ievent.nick = sshclient.transport.get_username()
        ievent.userhost = sshclient.userhost
        ievent.channel = sshclient.watch
        q = Queue.Queue()
        ievent.queues.append(q)
        ievent.speed = 3
        ievent.bot = sshclient.bot
        return ievent

    def fake_join(self, sshclient):
        '''Feed a JOIN for this session's watched channel to its bot.'''
        ievent = self.cook_ievent(sshclient)
        ievent.channel = sshclient.watch
        ievent.cmnd = 'JOIN'
        sshclient.bot.handle_ievent(ievent)
        sshclient.joined = True

    def fake_part(self, sshclient):
        '''Feed a PART for this session's watched channel to its bot.'''
        ievent = self.cook_ievent(sshclient)
        ievent.channel = sshclient.watch
        ievent.cmnd = 'PART'
        sshclient.bot.handle_ievent(ievent)

    def fake_quit(self, sshclient):
        '''Feed a QUIT to the bot, once, for a session that previously joined.'''
        # sessions dropped before fake_join() have no userhost and never joined
        if not getattr(sshclient, 'joined', False):
            return
        sshclient.joined = False
        ievent = self.cook_ievent(sshclient)
        ievent.channel = sshclient.watch
        ievent.txt = ''
        ievent.cmnd = 'QUIT'
        sshclient.bot.handle_ievent(ievent)

    def handle_chat(self, sshclient, buffer):
        '''Send buffer to the bot as a PRIVMSG on the watched channel.'''
        ievent = self.cook_ievent(sshclient)
        ievent.channel = sshclient.watch
        ievent.txt = buffer
        ievent.origtxt = ievent.txt
        ievent.cmnd = 'PRIVMSG'
        sshclient.bot.handle_ievent(ievent)

    def handle_command(self, sshclient, buffer):
        '''Dispatch buffer as a bot command and return the lines it replied.

        Replies are collected with waitforqueue() on the event's queue
        (timeout argument 60).
        '''
        ievent = self.cook_ievent(sshclient)
        ievent.txt = buffer
        if not plugins.woulddispatch(sshclient.bot, ievent):
            return ["can't dispatch %s" % buffer, ]
        thr.start_new_thread(plugins.trydispatch, (sshclient.bot, ievent))
        result = waitforqueue(ievent.queues[-1], 60)
        if not result:
            return ["can't dispatch %s" % buffer, ]
        return result

    def handle_sshd_command(self, sshclient, buffer):
        '''Run the matching command_<name> method for a /<name> [args] line.'''
        command = buffer
        args = ''
        if ' ' in buffer:
            command, args = buffer.split(' ', 1)
        command = command.lower()
        hook = getattr(self, 'command_%s' % command, None)
        if hook:
            return hook(sshclient, args)
        else:
            return ['command not found', ]

    # NOTE: command_* docstrings are printed verbatim by /help.

    def command_help(self, sshclient, args):
        '''[<command>] - Shows help on <command> or all available commands'''
        if args:
            command = args.lower()
            hook = getattr(self, 'command_%s' % command, None)
            if hook:
                return ['%s: %s' % (command, hook.__doc__), ]
            return ['command not found', ]
        reply = ['Use /command to control the terminal (see /help)', 'Use !command to dispatch a command to the bot', '']
        for attr in dir(self):
            if attr.startswith('command_'):
                hook = getattr(self, attr, False)
                if hook:
                    reply.append('    %s %s' % (attr[8:], hook.__doc__))
        return reply

    def command_me(self, sshclient, args):
        '''<text> - Simulate an IRC action'''
        ievent = self.cook_ievent(sshclient)
        ievent.channel = sshclient.watch
        ievent.txt = '\001ACTION %s\001' % args
        ievent.origtxt = ievent.txt
        # a CTCP ACTION travels as a PRIVMSG, like handle_chat
        ievent.cmnd = 'PRIVMSG'
        sshclient.bot.handle_ievent(ievent)
        return []

    def command_server(self, sshclient, args):
        '''<bot name> - Set current server to fleet bot <bot name>'''
        args = args.split()
        if len(args) != 1:
            return self.command_help(sshclient, 'server')
        bot = fleet.byname(args[0])
        if bot:
            sshclient.bot = bot
            return ['ok', ]
        return ['no such bot', ]

    def command_monitor(self, sshclient, args):
        '''<on/off> - Toggle the channel monitor'''
        args = args.split()
        if len(args) != 1 or args[0].lower() not in ['on', 'off']:
            return self.command_help(sshclient, 'monitor')
        sshclient.monitor = args[0].lower() == 'on'
        return ['monitor %s' % args[0].lower(), ]

    def command_watch(self, sshclient, args):
        '''<channel> - Set current channel to <channel>'''
        args = args.split()
        if len(args) != 1:
            return self.command_help(sshclient, 'watch')
        sshclient.watch = args[0].lower()
        return ['ok', ]

    def command_quit(self, sshclient, args):
        '''- Disconnect'''
        sshclient.reply('Goodbye\r\n')
        # loop_client sees this flag after the command returns and tears the
        # session down; closing here would leave it reading a dead channel
        sshclient.closing = True
        return []

    def drop_client(self, sshclient):
        '''Announce the client's departure (if it joined) and close its connection.'''
        channel = getattr(sshclient, 'channel', None)
        if channel is not None:
            rlog(5, 'sshd[%d]' % channel.get_id(), 'dropping connection to %s:%d' % tuple(sshclient.addr))
        else:
            rlog(5, 'sshd', 'dropping connection to %s:%d' % tuple(sshclient.addr))
        try:
            self.fake_quit(sshclient)
        except Exception:
            handle_exception()
        # best-effort close of the session channel, the ssh transport and the
        # raw socket; whichever were never created are skipped
        for name in ('channel', 'transport', 'client'):
            conn = getattr(sshclient, name, None)
            if conn is None:
                continue
            try:
                conn.close()
            except (socket.error, ssh.SSHException, EOFError):
                pass

def handle_sshdinstallkey(bot, ievent):
    '''<user> <url> - fetch a public key list from <url> and register its RSA keys for <user>.

    Each fetched line is inspected:
      * ssh-rsa lines are imported;
      * ssh-dss lines are skipped (DSA is not supported);
      * blank lines and '#' comment lines are ignored;
      * anything else is reported as invalid.
    A key that fails to decode is reported and does not stop the remaining keys.
    '''
    if len(ievent.args) != 2:
        ievent.missing('<user> <url>')
        return
    user, url = ievent.args
    if not users.exist(user):
        ievent.reply('invalid user')
        return
    result = []
    imported = 0
    for index, line in enumerate(geturl2(url).splitlines()):
        lineno = index + 1
        line = line.strip()
        if not line or line.startswith('#'):
            continue
        if line.startswith('ssh-dss'):
            result.append('skipped DSA key on line %d (not supported)' % lineno)
            continue
        fields = line.split()
        # "ssh-rsa <base64> [comment]": a line without the key blob is invalid
        if not line.startswith('ssh-rsa') or len(fields) < 2:
            result.append('invalid data on line %d' % lineno)
            continue
        try:
            user_keys.add_rsa(user, fields[1])
        except Exception:
            handle_exception()
            result.append('invalid RSA key on line %d' % lineno)
            continue
        imported += 1
    result.insert(0, 'imported %d RSA keys' % imported)
    ievent.reply(' .. '.join(result))

cmnds.add('sshd-installkey', handle_sshdinstallkey, 'OPER')

# Single-byte control characters recognised by the line editor.
KEYMAP = dict(
    ETX='\x03',  # ^C - End of Text
    EOT='\x04',  # ^D - End of Transmission
    BS='\x08',   # ^H - Backspace
    TAB='\x09',  # ^I - Horizontal Tab
    LF='\x0a',   # ^J - Line Feed
    FF='\x0c',   # ^L - Form Feed
    CR='\x0d',   # ^M - Carriage Return
    ETB='\x17',  # ^W - End of Transmission Block
    ESC='\x1b',  # ^[ - Escape
)

# Escape sequences (the bytes that follow ESC) recognised by the line editor.
SEQMAP = dict(
    up='[A',
    down='[B',
    right='[C',
    left='[D',
    home='[H',
    end='[F',
    pageup='[5~',
    pagedown='[6~',
    f1='OP',
    f2='OQ',
    f3='OS',
)

