#!/usr/bin/python -tt

import sys, os
import syslog
import pwd
import syslog
import glob
import base64
import commands
import re
import selinux
import subprocess
import itertools

# NOTE(review): not referenced anywhere in the visible part of this file.
EXT_LIB = "/usr/libexec/openshift/lib/util"

# Commands that are never executed directly; they are handed to bash, which
# is given a rewritten argument list by the dispatch code in __main__.
_BASH_WRAPPED = (
    "snapshot",
    "restore",
    "rhcsh",
    "java",
    "scp",
    "cd",
    "set",
    "mkdir",
    "test",
    "rsync",
    "ctl_all",
    "deploy.sh",
    "rhc-list-ports",
    "post_deploy.sh",
)

# Maps the base name of the requested command to the executable that is
# exec'd for it.  Anything absent from this map is run through bash as-is.
commands_map = dict.fromkeys(_BASH_WRAPPED, "/bin/bash")
commands_map.update({
    "git-receive-pack": "/usr/bin/git-receive-pack",
    "git-upload-pack": "/usr/bin/git-upload-pack",
    "tail": "/usr/bin/tail",
    "true": "/bin/true",
    "quota": "/usr/bin/quota",
})

# Matches a '#' comment running to the end of a line.
comment_re = re.compile("#.*$")

# These should come from somewhere, not be hard coded - MAL
openshift_cgroup_subsystems = "cpu,cpuacct,memory,net_cls,freezer"

def load_env(directory):
    """
    Export every file found under ``directory`` as an environment variable.

    ``directory`` may begin with ``~`` and may contain glob wildcards.  Each
    regular file yields one variable: the file's base name is the variable
    name and its contents, minus trailing whitespace, are the value.  When
    the contents start with ``export ``, only the text after the first ``=``
    is used, with surrounding quote characters removed.  Sub-directories are
    loaded recursively.
    """
    for entry in glob.glob(os.path.expanduser(os.path.join(directory, '*'))):
        if os.path.isdir(entry):
            # glob already returned the complete path; joining it onto
            # ``directory`` again would double a relative prefix.
            load_env(entry)
            continue

        with open(entry, 'r') as env_file:
            contents = env_file.read().rstrip()

        if contents.startswith('export '):
            # Split on the first '=' only so that values which themselves
            # contain '=' are kept whole.  A malformed "export NAME" with no
            # '=' falls through and is stored verbatim.
            _, separator, value = contents.partition('=')
            if separator:
                contents = value.strip('\'"')

        os.environ[os.path.basename(entry)] = contents


def gear_env():
    """
    Populate os.environ for the gear and assemble its PATH.

    Environment files are loaded from the system-wide directory, then from
    the user's directories, then from the primary cartridge's ``env``
    directory when OPENSHIFT_PRIMARY_CARTRIDGE_DIR is set.  The final PATH
    is built, in order, from: the primary cartridge's path element, every
    other OPENSHIFT_*_PATH_ELEMENT value, the PATH in effect after all
    loading, and the PATH as it stood right after the system-wide load.
    """
    load_env('/etc/openshift/env')
    system_path = os.environ['PATH']

    # TODO: remove '~/.env/.uservars' when V1 goes away
    for user_dir in ('~/.env/.uservars', '~/.env', '~/*/env/'):
        load_env(user_dir)

    cartridge_dir = os.environ.get('OPENSHIFT_PRIMARY_CARTRIDGE_DIR')
    if cartridge_dir is None:
        primary = ''
    else:
        primary = os.path.basename(cartridge_dir.rstrip('/'))
        load_env(os.path.join(cartridge_dir, 'env'))

    primary_key = "OPENSHIFT_%s_PATH_ELEMENT" % primary
    element_re = re.compile('^OPENSHIFT_.*_PATH_ELEMENT$')

    # The primary cartridge's element, when defined, always goes first.
    path_parts = []
    if primary_key in os.environ:
        path_parts.append(os.environ[primary_key])

    # Remaining cartridge elements follow in environment iteration order.
    for key in os.environ:
        if key != primary_key and element_re.match(key):
            path_parts.append(os.environ[key])

    if 'PATH' in os.environ:
        path_parts.append(os.environ['PATH'])
    path_parts.append(system_path)

    os.environ['PATH'] = ':'.join(path_parts)


# Calling oo-get-mcs-level gets into trouble due to BZ 957257.  Use an
# internal implementation that reproduces the same logic instead.
def get_mcs_level(uid):
    """
    Return the SELinux level string ("s0:cX,cY") assigned to ``uid``.

    Category names c0..c1023 are combined ``group_size`` at a time, in the
    order itertools.combinations yields them, and the combinations are
    numbered starting at ``uid_offset + 1``.  The combination whose number
    equals ``uid`` supplies the categories.

    Returns None when ``uid`` has no corresponding combination (it is not
    above ``uid_offset``, or it exceeds the number of combinations).
    """
    set_size = 1024
    group_size = 2
    uid_offset = 0
    mls_num = 0

    # Zero-based position of this uid's combination in the sequence.
    index = uid - uid_offset - 1
    if index < 0:
        return None

    mcs_set = ["c%d" % n for n in range(set_size)]
    combos = itertools.combinations(mcs_set, group_size)

    # Skip straight to the wanted combination; None if the sequence is
    # exhausted first.
    mcs_col = next(itertools.islice(combos, index, None), None)
    if mcs_col is None:
        return None

    return "s%d:%s" % (mls_num, ",".join(mcs_col))


def read_config(path='/etc/openshift/node.conf'):
    """
    Parse a KEY=value configuration file into a dictionary.

    Everything from the first '#' on a line is discarded (including a '#'
    that appears inside a quoted value), blank lines are skipped, and each
    remaining line is split on its first '='.  Whitespace around the key and
    the value is dropped, as are quote characters surrounding the value.

    ``path`` defaults to the node configuration file.

    A non-blank line without '=' is logged to syslog, reported on stderr,
    and terminates the process with exit status 2.
    """
    config = {}
    with open(path, 'r') as conf_file:
        lines = conf_file.read().split("\n")

    for line in lines:
        # remove comments, then leading and trailing white space
        clean_line = line.split('#', 1)[0].strip()
        if not clean_line:
            continue

        # Split on the first '=' only, so values may themselves contain '='.
        key, separator, value = clean_line.partition('=')
        if not separator:
            syslog.syslog("node config error: %s" % (line))
            sys.stderr.write("Error in node configuration")
            sys.exit(2)  # need to set the proper exit code

        # remove quotes from value strings
        config[key.strip()] = value.strip().strip('\'"')
    return config

#
# Join the user's cgroup if available
#
def join_cgroup():
    """
    Move the current process into the invoking user's OpenShift cgroup.

    Runs ``cgclassify`` against ``/openshift/<username>`` for the subsystems
    listed in ``openshift_cgroup_subsystems``.  A non-zero exit status is
    written to syslog; it is not raised to the caller.
    """
    pid = os.getpid()
    username = pwd.getpwuid(os.getuid())[0]
    cgroup = "/openshift/%s" % username

    # Build the command line as text, then split it into an argv list.
    argv = ("cgclassify -g %s:%s %d" % (openshift_cgroup_subsystems, cgroup, pid)).split()
    syslog.syslog("user %s: putting process %d in cgroups %s" % (username, pid, openshift_cgroup_subsystems))

    status = subprocess.call(argv)
    if status != 0:
        syslog.syslog("user %s: cgroup classification failed: retval = %d" % (username, status))

    # should raise an exception? MAL

if __name__ == '__main__':
    # Entry point: translate the command requested over SSH into the
    # executable and argument list permitted for it, verify the SELinux
    # context of this process, then exec the command.

    # first self-apply restrictions
    # join_cgroup()
    config = read_config()
    gear_env()

    # Without a requested command, fall back to the interactive shell.
    orig_cmd = os.environ.get('SSH_ORIGINAL_COMMAND', "rhcsh")
    syslog.syslog(orig_cmd)
    allargs = orig_cmd.split()

    # An empty or whitespace-only command leaves basecmd blank, so none of
    # the special cases below match and the catch-all applies.
    basecmd = os.path.basename(allargs[0]) if allargs else ''

    if basecmd in commands_map:
        cmd = commands_map[basecmd]
    else:
        # Catch all, just run the command as is via bash.
        cmd = "/bin/bash"
        allargs = ['-c', ' '.join(allargs)]

    if basecmd == 'snapshot':
        # This gets called with "snapshot"
        allargs = ['oo-snapshot']
    elif basecmd == 'restore':
        # This gets called with "restore <INCLUDE_GIT>"
        include_git = len(allargs) > 1 and allargs[1] == 'INCLUDE_GIT'

        allargs = ['oo-restore']
        if include_git:
            allargs.append('INCLUDE_GIT')
    elif basecmd == 'rhcsh':
        os.environ["PS1"] = "rhcsh> "
        if len(allargs) < 2:
            allargs = ['--init-file', '/usr/bin/rhcsh', '-i']
        else:
            allargs = ['--init-file', '/usr/bin/rhcsh', '-c', ' '.join(allargs[1:])]
    elif basecmd == 'ctl_all':
        allargs = ['-c', '. /usr/bin/rhcsh > /dev/null ; ctl_all %s' % allargs[-1]]
    elif basecmd in ('java', 'set', 'scp', 'cd', 'test', 'mkdir', 'rsync', 'deploy.sh', 'post_deploy.sh', 'rhc-list-ports'):
        allargs = ['-c', ' '.join(allargs)]
    elif basecmd == 'tail':
        # Accepted forms:
        #   tail <glob>...
        #   tail --opts <base64-encoded tail options> <glob>...
        files = []
        files_start_index = 1
        args = []
        add_follow = True
        if len(allargs) > 1 and allargs[1] == '--opts':
            files_start_index = 3
            # A missing payload is treated as "no options" instead of
            # failing on the index lookup.
            encoded_opts = allargs[2] if len(allargs) > 2 else ''
            args = base64.standard_b64decode(encoded_opts).split()
            for arg in args:
                if arg.startswith(('..', '/')):
                    print("All paths must be relative: " + arg)
                    sys.exit(88)
                elif arg == '-f' or arg == '-F' or arg.startswith('--follow'):
                    # The caller chose a follow mode; do not add our own.
                    add_follow = False

        # Symlinks and regular files are both accepted as matched.
        for glob_list in allargs[files_start_index:]:
            files.extend(glob.glob(glob_list))
        if not files:
            print("Could not find any files matching glob")
            sys.exit(32)

        allargs = list(args)
        if add_follow:
            allargs.append('-f')
        allargs.extend(files)
    elif basecmd in ('git-receive-pack', 'git-upload-pack'):
        # git repositories need to be parsed specially
        thearg = ' '.join(allargs[1:])
        # startswith/endswith also cope with a missing (empty) argument.
        if thearg.startswith("'") and thearg.endswith("'"):
            thearg = thearg.replace("'", "")
        thearg = thearg.replace("\\'", "")
        thearg = thearg.replace("//", "/")

        # replace leading tilde (~) with user's home path, then collapse any
        # '..' segments so they cannot step outside the gear base directory
        realpath = os.path.normpath(os.path.expanduser(thearg))
        gear_base = config['GEAR_BASE_DIR'].rstrip('/')
        # Require a path-component boundary after the base directory so a
        # sibling directory that merely shares the prefix is rejected.
        if realpath != gear_base and not realpath.startswith(gear_base + '/'):
            syslog.syslog("Invalid repository: not in openshift_root (%s) - %s: (%s)" %
                          (config['GEAR_BASE_DIR'], thearg, realpath))
            print("Invalid repository %s: not in application root" % thearg)
            sys.exit(3)

        if not os.path.isdir(realpath):
            syslog.syslog("Invalid repository %s (%s)" %
                          (thearg, realpath))
            print("Invalid repository %s: not a directory" % thearg)
            sys.exit(3)
        allargs = [thearg]

    elif basecmd == 'quota':
        allargs = ['--always-resolve']

    runcon = '/usr/bin/runcon'
    mcs_level = get_mcs_level(os.getuid())

    # The process must already be running in the context derived from its
    # uid; otherwise refuse to run anything.
    target_context = 'unconfined_u:system_r:openshift_t:%s' % mcs_level
    actual_context = selinux.getcon()[1]
    if target_context != actual_context:
        print("Invalid context: %s, expected %s\n" % (actual_context, target_context))
        sys.exit(40)
        # The runcon fallback below is deliberately kept but unreachable: at
        # the time of writing a patched ssh is in use.  Remove the exit
        # above and it should work on other platforms.
        os.execv(runcon, [runcon, target_context, cmd] + allargs)
        sys.exit(1)
    else:
        os.execv(cmd, [cmd] + allargs)
        sys.exit(1)
