#!/usr/bin/env python
import warnings
warnings.filterwarnings('ignore', '.*negative int.*')
import os, sys, optparse
import logging
import csv
import fsutils
from fsutils import AparcStatsParser, BadFileError, aparclogger

# Original Version - Douglas Greve, MGH
# Rewrite - Krish Subramaniam, MGH
# $Id: aparcstats2table,v 1.20.2.5 2013/02/11 19:48:32 mreuter Exp $

# globals
# Short alias for the logger imported from fsutils; the rest of this script
# emits its debug output through it.
l = aparclogger

# map of the delimiter choices accepted on the command line to the literal
# separator character handed to the table writer
delimiter2char = {'comma':',', 'tab':'\t', 'space':' ', 'semicolon':';'}

# Long-form help text; options_parse() hands it to optparse as the usage string.
HELPTEXT = """
Converts a cortical stats file created by recon-all and or
mris_anatomical_stats (eg, ?h.aparc.stats) into a table in which
each line is a subject and each column is a parcellation. By
default, the values are the area of the parcellation in mm2. The
first row is a list of the parcellation names. The first column is
the subject name. If the measure is thickness then the last column
is the mean cortical thickness.

The subjects list can be specified on either of two ways:
  1. Specify each subject after a -s flag 

            -s subject1 -s subject2 ... --hemi lh
  
  2. Specify all subjects after --subjects flag. --subjects does not have
     to be the last argument. Eg:

            --subjects subject1 subject2 ... --hemi lh

By default, it looks for the ?h.aparc.stats file based on the
Killiany/Desikan parcellation atlas. This can be changed with
'--parc parcellation' where parcellation is the parcellation to
use. An alternative is aparc.a2009s which was developed by
Christophe Destrieux. If this file is not found, it will exit
with an error unless --skip in which case it skips this subject
and moves on to the next.

By default, the area (mm2) of each parcellation is reported. This can
be changed with '--meas measure', where measure can be area, volume
(ie, volume of gray matter), thickness, thicknessstd, or meancurv.
thicknessstd is the standard dev of thickness across space.

Example:
 aparcstats2table --hemi lh --subjects 004 008 --parc aparc.a2009s 
    --meas meancurv -t lh.a2009s.meancurv.txt

lh.a2009s.meancurv.txt will have 3 rows: (1) 'header' with the name
of each structure, (2) mean curvature for each structure for
subject 004, and (3) mean curvature for each structure for
subject 008.

The --common-parcs flag writes only the ROIs which are common to all 
the subjects. Default behavior is it puts 0.0 in the measure of an ROI
which is not present in a subject. 

The --parcs-from-file <file> outputs only the parcs specified in the file
The order of the parcs in the file is maintained. Specify one parcellation
per line.

The --report-rois flag, for each subject, gives what ROIs that are present
in atleast one other subject is absent in current subject and also gives 
what ROIs are unique to the current subject.

The --transpose flag writes the transpose of the table. 
This might be a useful way to see the table when the number of subjects is
relatively less than the number of ROIs.

The --delimiter option controls what character comes between the measures
in the table. Valid options are 'tab' ( default), 'space', 'comma' and
'semicolon'.

The --skip option skips if it can't find a .stats file. Default behavior is
exit the program.

The --parcid-only flag writes only the ROIs name in the 1st row 1st column
of the table. Default is hemi_ROI_measure
"""

def options_parse():
    """
    Command Line Options Parser for aparcstats2table.

    Builds the optparse parser, parses the command line and validates the
    result. On any validation failure an 'ERROR: ...' message is printed and
    the program exits with status 1.

    Returns:
        the parsed optparse options object, with one extra attribute:
        options.parcs - None when --parcs-from-file was not given, otherwise
                        the list of non-blank, stripped lines of that file
                        (file order is kept).

    Side effects:
        sets the logger 'l' to DEBUG level when -v/--debug is given.
    """
    parser = optparse.OptionParser(version='$Id: aparcstats2table,v 1.20.2.5 2013/02/11 19:48:32 mreuter Exp $', usage=HELPTEXT)

    # help text
    h_sub = '(REQUIRED) subject1 <subject2 subject3..>'
    h_s = ' subjectname'
    h_subf = 'name of the file which has the list of subjects ( one subject per line)'
    h_qdec = 'name of the qdec table which has the column of subjects ids (fsid)'
    h_qdeclong = 'name of the longitudinal qdec table which has the column of tp ids (fsid) and subject templates (fsid-base)'
    h_hemi = '(REQUIRED) lh or rh'
    h_parc = 'parcellation.. default is aparc ( alt aparc.a2009s)'
    h_meas = 'measure: default is area ( alt volume, thickness, thicknessstd, meancurv, gauscurv, foldind, curvind)'
    h_skip = 'if a subject does not have input, skip it'
    h_t = '(REQUIRED) output table file'
    h_deli = 'delimiter between measures in the table. default is tab (alt comma, space, semicolon )'
    h_parcid = 'do not pre/append hemi/meas to parcellation name'
    h_common = 'output only the common parcellations of all the subjects given'
    h_parcfile = 'filename: output parcellations specified in the file'
    h_roi = 'print ROIs information for each subject'
    h_tr = 'transpose the table ( default is subjects in rows and ROIs in cols)'
    h_v = 'increase verbosity'

    # Add the options
    # --subjects and -s both feed the same destination ('subjects')
    parser.add_option('--subjects', dest='subjects', action='callback',
                      callback=fsutils.callback_var, help=h_sub)
    parser.add_option('-s', dest='subjects', action='append',
                      help=h_s)
    parser.add_option('--subjectsfile', dest='subjectsfile', help=h_subf)
    parser.add_option('--qdec', dest='qdec', help=h_qdec)
    parser.add_option('--qdec-long', dest='qdeclong', help=h_qdeclong)
    parser.add_option('--hemi', dest='hemi',
                      choices=('lh','rh'), help=h_hemi)
    parser.add_option('-t', '--tablefile', dest='outputfile',
                      help=h_t)
    parser.add_option('-p', '--parc', dest='parc',
                      default='aparc', help=h_parc)
    parser.add_option('-m', '--measure', dest='meas',
                      choices=('area','volume','thickness','thicknessstd','meancurv','gauscurv','foldind','curvind'),
                      default='area', help=h_meas)
    parser.add_option('-d', '--delimiter', dest='delimiter',
                      choices=('comma','tab','space','semicolon'),
                      default='tab', help=h_deli)
    parser.add_option('--skip', action='store_true', dest='skipflag',
                      default=False, help=h_skip)
    parser.add_option('--parcid-only', action='store_true', dest='parcidflag',
                      default=False, help=h_parcid)
    parser.add_option('--common-parcs', action='store_true', dest='commonparcflag',
                      default=False, help=h_common)
    parser.add_option('--parcs-from-file', dest='parcsfile',
                      help=h_parcfile)
    parser.add_option('--report-rois', action='store_true', dest='reportroiflag',
                      default=False, help=h_roi)
    parser.add_option('--transpose', action='store_true', dest='transposeflag',
                      default=False, help=h_tr)
    parser.add_option('-v', '--debug', action='store_true', dest='verboseflag',
                      default=False, help=h_v)
    (options, _args) = parser.parse_args()

    # error check
    if options.subjects is not None and len(options.subjects) < 1:
        print('ERROR: atleast 1 subject must be provided')
        sys.exit(1)

    # exactly one way of naming the subjects has to be used
    subject_sources = [options.subjects, options.subjectsfile,
                       options.qdec, options.qdeclong]
    given = len([src for src in subject_sources if src is not None])
    if given == 0:
        print('ERROR: Specify one of --subjects, --subjectsfile --qdec or --qdec-long')
        print('       or run with --help for help.')
        sys.exit(1)
    if given > 1:
        print('ERROR: Please specify just one of  --subjects, --subjectsfile --qdec or --qdec-long.')
        sys.exit(1)

    if not options.outputfile:
        print('ERROR: output table name should be specified')
        sys.exit(1)
    if not options.hemi:
        print('ERROR: hemisphere should be provided (lh or rh)')
        sys.exit(1)

    # parse the parcs file
    options.parcs = None
    if options.parcsfile is not None:
        try:
            f = open(options.parcsfile, 'r')
            try:
                parc_lines = [line.strip() for line in f]
            finally:
                # close the handle even if reading fails part-way
                f.close()
        except IOError:
            # the user asked for a parcellation filter; silently running
            # without it would produce a different table than requested
            print('ERROR: cannot read ' + options.parcsfile)
            sys.exit(1)
        # blank lines are not parcellation names
        options.parcs = [parc for parc in parc_lines if parc]

    if options.reportroiflag:
        print('WARNING: --report-rois deprecated. Use -v instead')

    if options.verboseflag:
        l.setLevel(logging.DEBUG)

    return options

"""
Args:
    the parsed 'options' 
Returns:
    a list of tuples of (specifier names ( subjects), path to the corresponding .stats files)
"""
def _read_subjects_file(path):
    """Return the stripped, non-blank lines of the subjects file at 'path'."""
    sf = open(path)
    try:
        names = [line.strip() for line in sf]
    finally:
        sf.close()
    return [name for name in names if name]


def _read_qdec_subjects(path, longitudinal=False):
    """
    Return the subject ids listed in the 'fsid' column of a qdec table.

    Rows whose fsid is empty or starts with '#' are skipped. When
    'longitudinal' is true each id is built as
    <fsid>.long.<fsid-base> from the 'fsid' and 'fsid-base' columns.
    The delimiter of the table is sniffed from its first 1024 bytes.
    """
    # 'rb' is the mode the Python 2 csv module expects
    f = open(path, 'rb')
    try:
        dialect = csv.Sniffer().sniff(f.read(1024))
        f.seek(0)
        subjects = []
        for row in csv.DictReader(f, dialect=dialect):
            fsid = row['fsid'].strip()
            # startswith() instead of fsid[0]: an empty cell must not
            # raise IndexError
            if not fsid or fsid.startswith('#'):
                continue
            if longitudinal:
                fsid = fsid + '.long.' + row['fsid-base'].strip()
            subjects.append(fsid)
    finally:
        # close the table even when a row is malformed
        f.close()
    return subjects


def assemble_inputs(options):
    """
    Work out which .stats file belongs to each subject.

    The subject list is taken from options.subjects, or, when given, is
    replaced by the content of --subjectsfile, --qdec or --qdec-long
    (options.subjects is overwritten in that case). A file that cannot be
    opened/read (IOError) prints an error and exits with status 1.

    Args:
        the parsed 'options'
    Returns:
        a list of tuples of (specifier name ( subject), path to the
        corresponding <hemi>.<parc>.stats file under
        <subjects dir>/<subject>/stats)
    """
    o = options
    # check the subjects dir
    subjdir = fsutils.check_subjdirs()
    print('SUBJECTS_DIR : %s' % subjdir)

    # in case the user gave --subjectsfile argument
    if o.subjectsfile is not None:
        o.subjects = []
        try:
            o.subjects = _read_subjects_file(o.subjectsfile)
        except IOError:
            print('ERROR: the file %s doesnt exist' % o.subjectsfile)
            sys.exit(1)
    if o.qdec is not None:
        o.subjects = []
        try:
            o.subjects = _read_qdec_subjects(o.qdec)
        except IOError:
            print('ERROR: the file %s doesnt exist' % o.qdec)
            sys.exit(1)
    if o.qdeclong is not None:
        o.subjects = []
        try:
            o.subjects = _read_qdec_subjects(o.qdeclong, longitudinal=True)
        except IOError:
            print('ERROR: the file %s doesnt exist' % o.qdeclong)
            sys.exit(1)

    stats_name = o.hemi + '.' + o.parc + '.stats'
    return [(sub, os.path.join(subjdir, sub, 'stats', stats_name))
            for sub in o.subjects]

"""
Args: 
    disorganized_table - the table is of the form (specifier, parc_measure_map)
    parcslist - list of parcellation names
    where parc_measure_map is a stable hashtable of keys parcellation names and values the measures.
    The table is disorganized because the length of the parc_measure_map may not be the same for all
    specifiers.
    parcellations present in parcslist are the only parcellations which go in the table.
    if any specifier doesn't have a parcellation, the measure is 0.0
Returns:
    rows - list of specifiers ( subjects)
    columns - list of parcellation names
    table - a stable 2d table of size len(rows) x len(columns)
"""
def make_table2d(disorganized_table, parcslist):
    """
    Turn the per-specifier measure maps into a rectangular 2d table.

    Args:
        disorganized_table - sequence of (specifier, parc_measure_map) pairs
        parcslist - the parcellation names that become the columns
    Returns:
        (rows, columns, table) where rows is the list of specifiers in input
        order, columns is parcslist itself, and table[specifier][parc] holds
        the measure, or 0.0 when the specifier has no such parcellation.
    """
    # ordered 2d table: outer keys are specifiers, inner keys are parcs
    table = fsutils.Ddict(fsutils.StableDict)
    for specifier, measures in disorganized_table:
        for parc in parcslist:
            try:
                value = measures[parc]
            except KeyError:
                # parcellation absent for this specifier
                value = 0.0
            table[specifier][parc] = value

    rows = [entry[0] for entry in disorganized_table]
    return rows, parcslist, table

"""
Args: 
    parsed options
    disorganized_table - the table is of the form (specifier, parc_measure_map)
    where parc_measure_map is a stable hashtable of keys parcellation names and values the measures.
    The table is disorganized because the length of the parc_measure_map may not be the same for all
    specifiers.
Returns:
    rows - list of specifiers ( subjects)
    columns - list of parcellation names
    table - a stable 2d table of size len(rows) x len(columns)
"""
def sanitize_table(options, disorganized_table):
    """
    Build a rectangular table out of per-specifier measure maps whose sets
    of parcellations may differ.

    Args:
        options - parsed options; only options.commonparcflag is read here
        disorganized_table - list of (specifier, parc_measure_map) pairs
    Returns:
        (rows, columns, table) as produced by make_table2d. The columns are
        the parcellations common to all specifiers when
        options.commonparcflag is set, otherwise every parcellation seen in
        any specifier (missing measures are filled with 0.0 by
        make_table2d). An empty input yields empty rows and columns instead
        of raising IndexError.
    """
    dt = disorganized_table

    if not dt:
        # nothing was parsed (e.g. every subject got skipped): there is no
        # first entry to seed the intersection from
        l.debug('No parsed stats to tabulate')
        return make_table2d(dt, [])

    per_spec_parcs = []
    # seed the running intersection with the first specifier's parcs
    intersection = dt[0][1].keys()
    for spec, parc_measure_map in dt:
        parcs = parc_measure_map.keys()
        per_spec_parcs.append(parcs)
        intersection = fsutils.intersect_order(intersection, parcs)
        l.debug('-'*20)
        l.debug('Specifier: '+spec)
        l.debug('Intersection upto now:')
        l.debug(intersection)

    # per_spec_parcs is a list of lists. Make it a flat list ( single list )
    all_parcs = [parc for parcs in per_spec_parcs for parc in parcs]
    union = fsutils.unique_union(all_parcs)
    l.debug('-'*20)
    l.debug('Union:')
    l.debug(union)

    if options.commonparcflag:
        # write only the common parcs ( intersection )
        return make_table2d(dt, intersection)
    # write all the parcs ever encountered
    # if there's no parcs for a certain .stats file, write the measure as 0.0
    return make_table2d(dt, union)

def write_table(options, rows, cols, table):
    """
    Dump the in-memory table to options.outputfile through
    fsutils.TableWriter.

    The top-left cell is '<hemi>.<parc>.<meas>' and the separator is the
    character mapped from options.delimiter. Unless --parcid-only was given
    the column titles are wrapped as '<hemi>_<title>_<meas>'. With
    --transpose the transposed table is written.
    """
    writer = fsutils.TableWriter(rows, cols, table)
    corner_label = '%s.%s.%s' % (options.hemi, options.parc, options.meas)
    separator = delimiter2char[options.delimiter]
    writer.assign_attributes(filename=options.outputfile,
                             row1col1=corner_label,
                             delimiter=separator)

    # we might need the hemisphere and measure info in columns as well
    if not options.parcidflag:
        prefix = options.hemi + '_'
        suffix = '_' + options.meas
        writer.decorate_col_titles(prefix, suffix)

    if options.transposeflag:
        writer.write_transpose()
        return
    writer.write()

if __name__=="__main__":
    # Command line options are parsed and error-checked here
    options = options_parse()
    l.debug('-- The options you entered --')
    l.debug(options)

    # Assemble the input stats files: list of (specifier, path) tuples
    subj_listoftuples = assemble_inputs(options)

    # Init the table in memory
    # is a list containing tuples of the form
    # [(specifier, parc_measure_map),] for all specifiers that parsed fine
    pretable = []

    # Parse the parc.stats files
    print('Parsing the .stats files')
    for specifier, filepath in subj_listoftuples:
        l.debug('-'*20)
        l.debug('Processing file ' + filepath)
        try:
            parsed = AparcStatsParser(filepath)
            # parcs filter from the command line
            if options.parcs is not None:
                parsed.parse_only(options.parcs)
            parc_measure_map = parsed.parse(options.meas)
        except BadFileError:
            # fetch the exception this way so the block parses regardless
            # of which 'except ... e' spelling the interpreter supports
            e = sys.exc_info()[1]
            if options.skipflag:
                print('Skipping ' + str(e))
                continue
            print('ERROR: The stats file '+str(e)+' is not found or is too small to be a valid statsfile')
            print('Use --skip flag to automatically skip bad stats files')
            sys.exit(1)

        l.debug('-- Parsed Parcs and Measures --')
        l.debug(parc_measure_map)
        pretable.append((specifier, parc_measure_map))

    # With no subjects, or with every subject skipped, there is nothing to
    # tabulate; stop with a clear message instead of failing while building
    # the table.
    if not pretable:
        print('ERROR: no valid stats file could be parsed, nothing to write')
        sys.exit(1)

    # Make sure the table has the same number of cols for all stats files
    # and merge them up, clean them up etc. More in the documentation of the fn.
    print('Building the table..')
    rows, columns, table = sanitize_table(options, pretable)

    # Write this table ( in memory ) to disk.. function uses TableWriter class
    print('Writing the table to %s' % options.outputfile)
    write_table(options, rows, columns, table)

    # reaching this point means success: exit with 0 exit code
    sys.exit(0)
