#!/usr/bin/python -tt
#
# Copyright 2012 Red Hat, Inc.
# License: GPLv2+
# Author: Toshio Kuratomi <toshio@fedoraproject.org>

# Version of this script stored as a tuple of tuples; presumably reads as
# version 0.2 -- confirm against the release notes
__version_info__ = ((0, 2),)

import os
import grp
import sys
import argparse

from getpass import getpass, getuser

from configobj import ConfigObj, flatten_errors
from fedora.client import AccountSystem, AuthError
from validate import Validator

# Count of AuthErrors seen so far.  retry_fas() increments this module-level
# counter, so it is shared by every call rather than reset per call.
retries = 0
# retry_fas() re-raises the AuthError once ``retries`` reaches this value
MAX_RETRIES = 3

def parse_commandline(args):
    '''Parse the command line into a plain dict of settings

    :arg args: list of command line arguments.  An entry equal to
        ``sys.argv[0]`` is ignored so that callers may pass the whole of
        ``sys.argv``.
    :returns: dict that always has the keys ``fas_conf`` (bool), ``users``
        (set of usernames) and ``groups`` (set of group names with the
        leading "@" removed).  ``limit_from``, ``fas_groups`` and
        ``username`` are only present when they were given on the command
        line so that merging this dict over values from the config files
        does not overwrite settings the user did not ask to change.

    argparse itself exits the program on invalid arguments or ``--help``.
    '''
    parser = argparse.ArgumentParser(
            description='Generate an ssh authorized_keys file using fas')
    parser.add_argument('users', metavar='USERS', type=str, nargs='+',
            help='FAS usernames and FAS group names preceded by "@".'
            ' For example: toshio skvidal @gitfas')
    parser.add_argument('--with-fas-conf', dest='fas_conf',
            action='store_true', default=True,
            help='Look for credentials in /etc/fas.conf.  On hosts that run'
            ' fasClient to set up users and groups, fas.conf will have'
            ' a username and password that will work for us.  This can only be'
            ' read if the user is root.  Otherwise a different method of'
            ' getting credentials will be used.')
    parser.add_argument('--without-fas-conf', dest='fas_conf',
            action='store_false', default=True,
            help='Do not look for credentials in /etc/fas.conf.  A different'
            ' method of getting credentials will be used instead.')

    # The following can be set in the config files so the default value is
    # taken from there.  A default of None lets us tell "not given on the
    # command line" apart from an explicit value.
    parser.add_argument('--limit-from', '-l', action='append', default=None,
            help='IP addresses that limit where the keys can be used to'
            ' ssh from')
    parser.add_argument('--with-fas-groups', dest='fas_groups',
            action='store_true', default=None,
            help='Query fas for members of groups.  The default is to use the'
            ' groups information on the local machine.  Querying fas can be'
            ' used on machines that do not get their group information from'
            ' fas or if you need group information that has been updated'
            ' recently but it is slower.')
    parser.add_argument('--without-fas-groups', dest='fas_groups',
            action='store_false', default=None,
            help='Use the groups information on the local machine instead of'
            ' querying fas for members of groups.')
    parser.add_argument('--username', '-u', type=str, action='store',
            default=None,
            help='Use this username when contacting fas.  This will not be'
            ' used if a username and password can be read from /etc/fas.conf')
    parsed = parser.parse_args(args)

    args_dict = {}
    args_dict['fas_conf'] = parsed.fas_conf
    if parsed.limit_from is not None:
        args_dict['limit_from'] = parsed.limit_from
    if parsed.fas_groups is not None:
        args_dict['fas_groups'] = parsed.fas_groups
    if parsed.username is not None:
        args_dict['username'] = parsed.username

    users = set()
    groups = set()
    for entry in parsed.users:
        if entry.startswith('@'):
            groups.add(entry[1:])
        elif entry == sys.argv[0]:
            # When the caller hands us all of sys.argv, argparse treats the
            # program name as one of the USERS; drop it.
            continue
        else:
            users.add(entry)

    args_dict['users'] = users
    args_dict['groups'] = groups

    return args_dict

def read_config_files(cfg_files):
    '''Read settings from a list of config files and validate them

    :arg cfg_files: iterable of config file paths.  ``~`` is expanded and
        each file is merged over the ones listed before it.
    :returns: the merged ConfigObj after validation against the spec below

    If validation fails, one message per problem is written to stderr and the
    program exits with status 1.  stderr is used because stdout carries the
    generated authorized_keys content.
    '''
    config_spec = (
            'limit_from = ip_addr_list(default=list())',
            'fas_groups = boolean(default=False)',
            'username = string(default=%s)' % getuser(),
            )

    options = ConfigObj(configspec=config_spec)

    for cfg_file in cfg_files:
        cfg_file = os.path.abspath(os.path.expanduser(cfg_file))
        options.merge(ConfigObj(cfg_file, configspec=config_spec))

    results = options.validate(Validator())
    # Anything other than True is handed to flatten_errors() for reporting
    if results != True:
        for (section_list, key, unused_) in flatten_errors(options, results):
            if key is not None:
                message = 'The "%s" key in the section "%s" failed validation' % (
                        key, ', '.join(section_list))
            else:
                message = 'The following section was missing:%s ' % ', '.join(
                        section_list)
            sys.stderr.write(message + '\n')
        sys.exit(1)

    return options

def read_fas_conf():
    '''Read username and password from /etc/fas.conf

    :returns: tuple of (login, password) taken from the ``global`` section

    This function will throw an exception if it can't read the information
    (for instance when the file cannot be opened or an entry is missing).
    '''
    # Turn comment markers ";" at the start of a line into "#" before
    # handing the lines to ConfigObj.  The with statement makes sure the
    # file is closed even if reading fails.
    with open('/etc/fas.conf', 'r') as cfgfile:
        lines = [line.replace(';', '#', 1) if line.startswith(';') else line
                for line in cfgfile]
    cfg = ConfigObj(lines)
    return cfg['global']['login'], cfg['global']['password']

def members_of_group(group, use_fas=False):
    '''Return the set of usernames that are members of a group

    :arg group: name of the group to look up
    :kwarg use_fas: if True, ask fas (through the module-level ``fas``
        client and retry_fas()) for the members.  Otherwise, the default,
        read them from the local group database.
    :returns: set of usernames
    '''
    if not use_fas:
        # Index 3 of the local group entry holds the member names
        return set(grp.getgrnam(group)[3])

    people = retry_fas(fas.group_members, group)
    return set([person.username for person in people])

def iterate_ssh_keys(user):
    '''Yield each ssh public key that fas has stored for a user

    :arg user: FAS username to look up
    :returns: generator of key strings, one per non-blank line of the
        user's ``ssh_key`` field, with surrounding whitespace removed

    The lookup goes through retry_fas() using the module-level ``fas``
    client.  Indexing the response by ``user`` presumably raises KeyError
    when fas does not know the user -- TODO confirm.
    '''
    people = retry_fas(fas.people_by_key, search=user, fields=['ssh_key'])
    person = people[user]
    if not person.ssh_key:
        return
    # splitlines() + strip() copes with CRLF line endings and stray
    # whitespace, which would otherwise end up inside authorized_keys
    for line in person.ssh_key.splitlines():
        key = line.strip()
        if key:
            yield key

def retry_fas(function, *args, **kwargs):
    '''Call a bound method of the fas client, asking for a password on AuthError

    :arg function: bound method to call.  The object it is bound to must
        have ``username`` and ``password`` attributes.
    :arg args: positional arguments passed through to ``function``
    :kwarg kwargs: keyword arguments passed through to ``function``
    :returns: whatever ``function`` returns
    :raises AuthError: once the module-level ``retries`` counter reaches
        ``MAX_RETRIES``

    The counter is global, so the limit applies to the sum of failures
    across all calls rather than to each call separately.
    '''
    global retries
    # Try the function at least once
    while True:
        try:
            return function(*args, **kwargs)
        except AuthError:
            retries += 1
            if retries >= MAX_RETRIES:
                # Give up before prompting; a password entered here would
                # never be tried
                raise
            client = function.__self__
            client.password = getpass('FAS Password for %s:' % client.username)

if __name__ == '__main__':
    # argparse expects only the arguments.  Including the program name would
    # satisfy the "at least one USERS entry" requirement on its own and hide
    # the usage error when nothing was given.
    args = parse_commandline(sys.argv[1:])

    conf = read_config_files(['/etc/fas-auth-keys.conf',
        '~/.fedora/fas-auth-keys.conf'])

    # Merge args and conf.  Note that this requires that we've already parsed
    # our args so that they're in a dict and won't overwrite pieces of conf if
    # they weren't specified on the commandline
    conf.merge(args)

    # Merge in fas.conf
    if conf['fas_conf']:
        try:
            username, password = read_fas_conf()
        except Exception:
            # We don't care -- this is a nicety.  Exception (rather than a
            # bare except) still lets Ctrl-C and sys.exit() through.
            pass
        else:
            conf['username'] = username
            conf['password'] = password

    # No password is known unless fas.conf was read, so don't assume the key
    # exists.  retry_fas() prompts for one when fas reports an AuthError.
    # ``fas`` has to stay a module-level name: members_of_group() and
    # iterate_ssh_keys() look it up as a global.
    fas = AccountSystem(username=conf['username'],
            password=conf.get('password'))

    from_string = ''
    if conf['limit_from']:
        from_string = 'from="%s" ' % ','.join(conf['limit_from'])

    # Expand every requested group into its members
    for group in conf['groups']:
        conf['users'].update(members_of_group(group, use_fas=conf['fas_groups']))

    # A set per user drops duplicate keys
    ssh_keys = dict()
    for user in conf['users']:
        ssh_keys[user] = set(iterate_ssh_keys(user))

    # Sort users and keys so repeated runs produce identical output
    for user in sorted(ssh_keys):
        for key in sorted(ssh_keys[user]):
            print('%s%s' % (from_string, key))