#!/usr/bin/python -tt

import sys, os
import syslog
import pwd
import syslog
import glob
import base64
import commands
import re
import selinux

# Library path constant; nothing in the code visible here reads it.
EXT_LIB = "/usr/libexec/openshift/lib/util"

# Maps the basename of a requested command to the executable used to run it.
# Most commands are run through bash; the exceptions are added afterwards.
commands_map = dict.fromkeys(
    (
        "snapshot",
        "restore",
        "rhcsh",
        "java",
        "scp",
        "cd",
        "set",
        "mkdir",
        "test",
        "rsync",
        "ctl_all",
        "deploy.sh",
        "rhc-list-ports",
        "post_deploy.sh"
    ),
    "/bin/bash"
)
commands_map.update({
    "git-receive-pack": "/usr/bin/git-receive-pack",
    "git-upload-pack": "/usr/bin/git-upload-pack",
    "tail": "/usr/bin/tail",
    "true": "/bin/true",
    "quota": "/usr/bin/quota"
})

# Matches a '#' and everything after it up to the end of the line.
comment_re = re.compile("#.*$")

#
# Read in uservars variables.
#
def _set_env_uservars(uservars_dir):
    """
    Export every regular file in uservars_dir as an environment variable.

    The file name is used as the variable name.  The value is the first line
    of the file with surrounding whitespace and surrounding quote characters
    removed; an empty file yields an empty value.  Entries that are not
    regular files are skipped.

    Values are set with os.putenv(), so they are inherited by programs
    started afterwards but do not show up in os.environ.
    """
    for env in os.listdir(uservars_dir):
        path = os.path.join(uservars_dir, env)
        if not os.path.isfile(path):
            # open() would fail on a directory; only files carry a value.
            continue

        fp = open(path, 'r')
        try:
            # readline() returns '' for an empty file instead of raising.
            env_var = fp.readline().strip().strip('\'"')
        finally:
            fp.close()
        os.putenv(env, env_var)

#
# Read in environment variables
#
def read_env_vars():
    """
    Export the settings stored under ~/.env/ into the process environment.

    Each regular entry is expected to hold a first line of the form
    ``<anything>=<value>``.  The entry's file name becomes the variable name
    and everything after the first '=' (with surrounding quote characters
    removed) becomes the value.

    Entries named USER_VARS or TYPELESS_TRANSLATED_VARS are ignored.  The
    '.uservars' directory is handed to _set_env_uservars(); any other
    directory is ignored.  An entry whose first line has no '=' (including an
    empty file) is reported to syslog and skipped.

    Values are set with os.putenv(), so they are inherited by programs
    started afterwards but do not show up in os.environ.
    """
    envdir = os.path.expanduser('~/.env/')
    for env in os.listdir(envdir):
        if env in ['USER_VARS', 'TYPELESS_TRANSLATED_VARS']:
            continue

        path = envdir + env
        if os.path.isdir(path):
            if '.uservars' == env:
                _set_env_uservars(path)
            continue

        fp = open(path, 'r')
        try:
            first_line = fp.readline().strip()
        finally:
            fp.close()

        # Split on the first '=' only, so a value that itself contains '='
        # (for example a URL with a query string) is kept whole.
        parts = first_line.split('=', 1)
        if len(parts) != 2:
            syslog.syslog("ignoring malformed env file: %s" % (path))
            continue
        os.putenv(env, parts[1].strip('\'"'))

def get_mcs_level(uid):
    """
    Return whatever /usr/bin/oo-get-mcs-level prints for the given uid.

    The helper's exit status is not inspected; its output text is returned
    unchanged.
    """
    status_and_output = commands.getstatusoutput(
        "/usr/bin/oo-get-mcs-level %s" % (uid))
    _, output = status_and_output
    return output

def read_config(config_path='/etc/openshift/node.conf'):
    """
    Parse the node configuration file into a dict of key -> value strings.

    Everything from a '#' to the end of a line is discarded, and blank lines
    are skipped.  Every remaining line must contain an '='.  The text before
    the first '=' is the key and the text after it is the value; whitespace
    around both is removed, as are quote characters surrounding the value.
    Only the first '=' separates key from value, so values may contain '='.

    config_path -- file to read; defaults to /etc/openshift/node.conf.

    If a line has no '=', the line is logged to syslog, a message is written
    to stderr and the process exits with status 2.
    """
    f = open(config_path, 'r')
    try:
        data = f.read()
    finally:
        f.close()

    config = {}
    for line in data.split("\n"):
        # Drop the comment part, then leading and trailing white space.
        clean_line = line.split("#", 1)[0].strip()
        if clean_line == "":
            continue

        split_line = clean_line.split("=", 1)
        if len(split_line) != 2:
            syslog.syslog("node config error: %s" % (line))
            sys.stderr.write("Error in node configuration")
            sys.exit(2)  # need to set the proper exit code

        key = split_line[0].strip()
        # Strip white space first so that quotes next to it are removed too.
        value = split_line[1].strip().strip('\'"')
        config[key] = value
    return config

#
# Join the user's cgroup if available
#
def join_cgroup():
    """
    Add the current process to the calling user's cgroup when one exists.

    The user's cgroup is the directory named after the user below
    /cgroup/all/openshift.  The attempt is always logged to syslog.  If the
    root directory, the user's directory or its "tasks" file is missing, the
    function returns without doing anything else.  Otherwise the process id
    is written to the "tasks" file; an error raised by that write is not
    caught here.
    """
    me = os.getpid()
    login = pwd.getpwuid(os.getuid())[0]

    # this should come from /etc/openshift/node.conf:OPENSHIFT_CGROUP_ROOT
    cgroup_root = "/cgroup/all/openshift"
    user_dir = os.path.join(cgroup_root, login)
    tasks_path = os.path.join(user_dir, "tasks")

    syslog.syslog("user %s: putting process %d in cgroup %s" % (login, me, cgroup_root))

    # All three must be present; checked in order root, user dir, tasks file.
    joinable = (os.path.isdir(cgroup_root)
                and os.path.isdir(user_dir)
                and os.path.isfile(tasks_path))
    if not joinable:
        return

    taskfile = open(tasks_path, 'w')
    taskfile.write("%d\n" % me)
    taskfile.flush()
    taskfile.close()
if __name__ == '__main__':
    # Entry point: turn SSH_ORIGINAL_COMMAND into an executable plus argument
    # list and exec it.
    #
    # Exit codes used below:
    #    3  git repository argument rejected
    #   32  tail: no file matched the given globs
    #   40  current SELinux context differs from the expected one
    #   88  tail: an option decoded from --opts starts with '..' or '/'

    # first self-apply restrictions
    join_cgroup()
    read_env_vars()
    config = read_config()

    orig_cmd = os.environ.get('SSH_ORIGINAL_COMMAND', "rhcsh")
    syslog.syslog(orig_cmd)
    allargs = orig_cmd.split()

    # Bind basecmd up front: an empty or whitespace-only command leaves
    # allargs empty, and the dispatch below must still be able to test it.
    basecmd = ''
    try:
        basecmd = os.path.basename(allargs[0])
        cmd = commands_map[basecmd]
    except (IndexError, KeyError):
        # Catch all, just run the command as is via bash.
        cmd = "/bin/bash"
        allargs = ['-c', ' '.join(allargs)]

    if basecmd == 'snapshot':
        # This gets called with "snapshot"
        allargs = ['snapshot.sh']
    elif basecmd == 'restore':
        # This gets called with "restore <INCLUDE_GIT>"
        include_git = len(allargs) > 1 and allargs[1] == 'INCLUDE_GIT'
        allargs = ['restore.sh']
        if include_git:
            allargs.append('INCLUDE_GIT')
    elif basecmd == 'rhcsh':
        os.environ["PS1"] = "rhcsh> "
        if len(allargs) < 2:
            allargs = ['--init-file', '/usr/bin/rhcsh', '-i']
        else:
            allargs = ['--init-file', '/usr/bin/rhcsh', '-c',
                       ' '.join(allargs[1:])]
    elif basecmd == 'ctl_all':
        allargs = ['-c', '. /usr/bin/rhcsh > /dev/null ; ctl_all %s' % allargs[-1]]
    elif basecmd in ('java', 'set', 'scp', 'cd', 'test', 'mkdir', 'rsync',
                     'deploy.sh', 'post_deploy.sh', 'rhc-list-ports'):
        # Hand the whole command line to bash unchanged.
        allargs = ['-c', ' '.join(allargs)]
    elif basecmd == 'tail':
        # Accepted forms:
        #   tail <glob>...
        #   tail --opts <base64 encoded tail options> <glob>...
        files = []
        files_start_index = 1
        args = []
        add_follow = True
        if len(allargs) > 1 and allargs[1] == '--opts':
            files_start_index = 3
            if len(allargs) > 2:
                args = base64.standard_b64decode(allargs[2]).split()
            for arg in args:
                if arg.startswith(('..', '/')):
                    print("All paths must be relative: " + arg)
                    sys.exit(88)
                elif arg == '-f' or arg == '-F' or arg.startswith('--follow'):
                    # The caller already chose a follow mode.
                    add_follow = False

        for glob_list in allargs[files_start_index:]:
            files.extend(glob.glob(glob_list))
        if len(files) == 0:
            # Also reached when no glob was supplied at all.
            print("Could not find any files matching glob")
            sys.exit(32)

        allargs = list(args)
        if add_follow:
            allargs.append('-f')
        allargs.extend(files)
    elif basecmd in ('git-receive-pack', 'git-upload-pack'):
        # git repositories need to be parsed specially
        thearg = ' '.join(allargs[1:])
        # startswith/endswith are safe on an empty argument; an empty
        # argument is then rejected by the checks below.
        if thearg.startswith("'") and thearg.endswith("'"):
            thearg = thearg.replace("'", "")
        thearg = thearg.replace("\\'", "")
        thearg = thearg.replace("//", "/")

        # replace leading tilde (~) with user's home path
        realpath = os.path.expanduser(thearg)
        # NOTE(review): this is a plain string-prefix test; '..' components
        # and sibling directories sharing the prefix are not ruled out.
        # Confirm whether the path should be normalised before comparing.
        if not realpath.startswith(config['GEAR_BASE_DIR']):
            syslog.syslog("Invalid repository: not in openshift_root (%s) - %s: (%s)" %
                          (config['GEAR_BASE_DIR'], thearg, realpath))
            print("Invalid repository %s: not in application root" % thearg)
            sys.exit(3)

        if not os.path.isdir(realpath):
            syslog.syslog("Invalid repository %s (%s)" %
                          (thearg, realpath))
            print("Invalid repository %s: not a directory" % thearg)
            sys.exit(3)
        allargs = [thearg]
    elif basecmd == 'quota':
        allargs = []

    runcon = '/usr/bin/runcon'
    mcs_level = get_mcs_level(os.getuid())

    target_context = 'unconfined_u:system_r:openshift_t:%s' % mcs_level
    actual_context = selinux.getcon()[1]
    if target_context != actual_context:
        print("Invalid context")
        sys.exit(40)
        # The statements below are unreachable while the exit above is in
        # place.  They are left in because at the time of writing we have a
        # patched ssh running.  Remove the exit above and it should work on
        # other platforms.
        os.execv(runcon, [runcon, target_context, cmd] + allargs)
        sys.exit(1)
    else:
        os.execv(cmd, [cmd] + allargs)
        sys.exit(1)
