#!/usr/bin/env python
# -*- coding: latin-1 -*-

#
# merge_stats_tables
#
# script to merge stats tables
#
# Original Author: Martin Reuter
# CVS Revision Info:
#    $Author: mreuter $
#    $Date: 2011/05/10 21:17:29 $
#    $Revision: 1.1.2.5 $
#
# Copyright © 2011 The General Hospital Corporation (Boston, MA) "MGH"
#
# Terms and conditions for use, reproduction, distribution and contribution
# are found in the 'FreeSurfer Software License Agreement' contained
# in the file 'LICENSE' found in the FreeSurfer distribution, and here:
#
# https://surfer.nmr.mgh.harvard.edu/fswiki/FreeSurferSoftwareLicense
#
# Reporting: freesurfer@nmr.mgh.harvard.edu
#
#

import warnings
warnings.filterwarnings('ignore', '.*negative int.*')
import os
import sys
import shlex
import optparse
import logging
import subprocess
import tempfile
import shutil
from subject_info import *
import fsutils

# logging: handler that emits the log records (a StreamHandler created
# without a stream argument writes to sys.stderr)
ch = logging.StreamHandler()
# create the logger of this script; the level is INFO by default and is
# switched to DEBUG by options_parse() when -v/--debug is given
logger = logging.getLogger("merge_stats_tables")
logger.setLevel(logging.INFO)
logger.addHandler(ch)

# globals
# short alias for the logger, used by the functions in this file
l = logger

# map of delimiter choices (the values accepted by --delimiter) to the
# literal strings placed between table entries
delimiter2char = {'comma':',', 'tab':'\t', 'space':' ', 'semicolon':';'}


# Usage text handed to optparse.OptionParser (shown with --help).
# NOTE: the sections on --segids-from-file, --segno and --no-segno describe
# options whose definitions are currently commented out in options_parse().
HELPTEXT = """

SUMMARY

Merges a set of stats table files into a single stats table.
A stats table is a file where each line is a subject and each column 
is a segmentation or parcellation. The first row is the name of the
table (measure) and a list of the ROI names. The first column is
the subject name.

The subjects list can be specified in one of four ways:
  1. Specify each subject after -s 
  
          -s subject1 -s subject2 ..
  
  2. specify all subjects after --subjects.  
     --subjects does not have to be the last argument. Eg:
     
          --subjects subject1 subject2 ... 

  3. Specify each input file after -i 

          -i subject1/stats/aseg.stats -i subject2/stats/aseg.stats ...
  
  4. Specify all the input stat files after --inputs. --inputs does not have
     to be the last argument. Eg:
       
          --inputs subject1/stats/aseg.stats subject2/stats/aseg.stats ...

The first two methods assume the freesurfer directory structure. The
last two are general and can be used with any stats table input file.
However for inputs, the subject name is not printed in the file (just the
row number). Note that the first two and last two are mutually exclusive.
i.e don't specify --subjects when you are providing --inputs and vice versa.

The --common-segs flag outputs only the segmentations which are common to *all*
the statsfiles. This option is helpful if one or more statsfile contains
segmentations different from the segs of other files ( which results in the
script exiting which is the default behavior ). This option makes the
script to continue.

The --all-segs flag outputs segmentations which are the union of all
segmentations in all statsfiles. This option is helpful if one or more statsfile
contains segs different from the segs of other files ( which results in the
script exiting, the default behavior ). Subjects which don't have a certain
segmentation show a value of 0.

The --segids-from-file <file> option outputs only the segmentations present in the file.
There has to be one segmentation id per line in the file. The output table will maintain the 
order of the segmentation ids

The --segno option outputs only the segmentation id requested.
This is useful because if the number of segmentations is large,
the table becomes huge. The order of the specified seg ids is maintained. 

The --no-segno options doesn't output the segmentations. 
This can be convenient for removing segs that are always empty.

The --transpose flag writes the transpose of the table. 
This might be a useful way to see the table when the number of subjects is
relatively less than the number of segmentations.

The --delimiter option controls what character comes between the measures
in the table. Valid options are 'space' (default), 'tab', 'comma' and 'semicolon'.

The --skip option skips if it can't find a .stats file. Default behavior is
exit the program.

"""

def _options_error(message):
    """Print `message` on stdout and terminate the script with exit status 1."""
    print(message)
    sys.exit(1)


def options_parse():
    """
    Command line options parser for merge_stats_tables.

    Builds the optparse parser, parses sys.argv and validates the result.
    Any validation failure prints an 'ERROR: ...' message and exits with
    status 1.

    Returns:
        the parsed options object, with one extra attribute `dodirect`:
        True when the input files were given directly (--inputs / -i),
        False when subjects were given (--subjects / -s / --subjectsfile).

    Side effect: raises the level of the module logger to DEBUG when
    -v/--debug is given.
    """
    parser = optparse.OptionParser(version='$Id: merge_stats_tables,v 1.1.2.5 2011/05/10 21:17:29 mreuter Exp $', usage=HELPTEXT)

    # help text
    h_sub = '(REQUIRED) subject1 <subject2 subject3..>'
    h_s = ' subjectname'
    h_subf = 'name of the file which has the list of subjects ( one subject per line)'
    h_inp = ' input1 <input2 input3..>'
    h_i = ' inputname'
    h_meas = '(REQUIRED) measure to write in output table'
    h_common = 'output only the common segmentations of all the statsfiles given'
    h_all = 'output all the segmentations of the statsfiles given'
    h_intable = 'use `fname` as input (REQUIRED when passing subject ids)'
    h_subdir = 'use `subdir` instead of default "stats/" when passing subject ids'
    h_tr = 'transpose the table (default is subjects in rows and segmentations in cols)'
    h_t = '(REQUIRED) the output tablefile'
    # the default below is 'space', so the alternatives are comma, tab and semicolon
    h_deli = 'delimiter between measures in the table. default is space (alt comma, tab, semicolon )'
    h_skip = 'if a subject does not have stats file, skip it'
    h_v = 'increase verbosity'

    # Add options
    # --subjects / --inputs take a variable number of arguments, collected by
    # the fsutils.callback_var callback
    parser.add_option('--subjects', dest='subjects', action='callback',
                      callback=fsutils.callback_var, help=h_sub)
    parser.add_option('-s', dest='subjects', action='append', help=h_s)
    parser.add_option('--subjectsfile', dest='subjectsfile', help=h_subf)
    parser.add_option('--inputs', dest='inputs', action='callback',
                      callback=fsutils.callback_var, help=h_inp)
    parser.add_option('-i', dest='inputs', action='append', help=h_i)
    parser.add_option('-t', '--tablefile', dest='outputfile', help=h_t)
    parser.add_option('-m', '--meas', dest='meas', help=h_meas)
    # NOTE: the segmentation filters --maxsegno, --segids-from-file, --segno
    # and --no-segno (together with their validation) are not enabled in this
    # script.
    parser.add_option('--common-segs', dest='common_flag', action='store_true', default=False, help=h_common)
    parser.add_option('--all-segs', dest='all_flag', action='store_true', default=False, help=h_all)
    parser.add_option('--intable', dest='intable', help=h_intable)
    parser.add_option('--subdir', dest='subdir', help=h_subdir)
    parser.add_option('-d', '--delimiter', dest='delimiter',
                      choices=('comma','tab','space','semicolon'),
                      default='space', help=h_deli)
    parser.add_option('', '--transpose', action='store_true', dest='transposeflag',
                      default=False, help=h_tr)
    parser.add_option('--skip', action='store_true', dest='skipflag',
                      default=False, help=h_skip)
    parser.add_option('-v', '--debug', action='store_true', dest='verboseflag',
                      default=False, help=h_v)

    (options, args) = parser.parse_args()

    # extensive error checks
    have_subjects = options.subjects is not None
    have_inputs = options.inputs is not None
    have_subjectsfile = options.subjectsfile is not None

    # an option that was given but collected no values
    if have_subjects and len(options.subjects) < 1:
        _options_error('ERROR: subjects are not specified (use --subjects SUBJECTS)')
    if have_inputs and len(options.inputs) < 1:
        _options_error('ERROR: inputs are not specified')

    # exactly one of the three ways of naming the inputs must be used
    if not (have_subjects or have_inputs or have_subjectsfile):
        _options_error('ERROR: Specify one of --subjects, --inputs or --subjectsfile\n'
                       '       or run with --help for help.')
    if have_subjects and have_inputs:
        _options_error('ERROR: Both subjects and inputs are specified. Please specify just one ')
    if have_subjects and have_subjectsfile:
        _options_error('ERROR: Both subjectsfile and subjects are specified. Please specify just one ')
    if have_inputs and have_subjectsfile:
        _options_error('ERROR: Both subjectsfile and inputs are specified. Please specify just one ')

    if not options.outputfile:
        _options_error('ERROR: output table name should be specified (use --tablefile FILE)')

    if not options.meas:
        _options_error('ERROR: output measure should be specified (use --meas STRING)')

    if (have_subjects or have_subjectsfile) and options.intable is None:
        _options_error('ERROR: input filename should be specified when passing subject ids (use --intable NAME)')

    if options.all_flag and options.common_flag:
        _options_error('ERROR: specify either --all-segs or --common-segs')

    # after the exclusivity checks exactly one input method is active:
    # direct file paths (--inputs / -i) or subject names
    options.dodirect = have_inputs

    if options.verboseflag:
        l.setLevel(logging.DEBUG)

    return options
    
def assemble_inputs(o):
    """
    Turn the parsed command line options into the list of stats table paths.

    Args:
        o: the parsed options (as returned by options_parse)
    Returns:
        a list of paths, one per stats table file to read

    When o.dodirect is set, the paths given with --inputs / -i are returned
    as they are. Otherwise one path per subject is built as
    <subjects dir>/<subject>/<o.subdir>/<o.intable>, where the subjects dir
    is whatever fsutils.check_subjdirs() returns.

    Side effects in the subjects case: o.subdir is set to 'stats' when it
    was not given, and o.subjects is replaced by the names read from
    o.subjectsfile when that option was given. The script exits with
    status 1 when the subjects file cannot be read.
    """
    # in the case of --inputs specification
    if o.dodirect:
        return list(o.inputs)

    # in the case of --subjects spec or --subjectsfile spec
    # check subjects dir
    subjdir = fsutils.check_subjdirs()
    print('SUBJECTS_DIR : %s' % subjdir)
    if o.subdir is None:
        o.subdir = 'stats'

    # in case the user gave --subjectsfile argument
    if o.subjectsfile is not None:
        try:
            sf = open(o.subjectsfile)
            try:
                names = [line.strip() for line in sf]
            finally:
                # always release the file handle, also when reading fails
                sf.close()
        except IOError:
            print('ERROR: the file %s doesnt exist' % o.subjectsfile)
            sys.exit(1)
        # blank lines (e.g. a trailing newline) are not subjects; keeping
        # them would produce a bogus path without a subject component
        o.subjects = [name for name in names if name]

    return [os.path.join(subjdir, sub, o.subdir, o.intable)
            for sub in o.subjects]


def make_table2d(disorganized_table, segnamelist):
    """
    Build a proper 2d table from a disorganized one.

    Args:
        disorganized_table: a sequence of tuples
                (spec1, id_name_map1, measurelist1)
                (spec2, id_name_map2, measurelist2)
                ..
            specN        - either the name of the subject or the number of
                           the stat file
            id_name_mapN - a dict with keys=segmentation ids and
                           values=segmentation names of the specN file
            measurelistN - list of measures, in the order of the values of
                           id_name_mapN
            (the table is disorganized because the lengths of id_name_mapN
            and measurelistN need not be the same for all N)
        segnamelist: list of segmentation names forming the columns

    Returns:
        rows, columns, table
        rows    - the list of specN
        columns - segnamelist, unchanged
        table   - an fsutils.Ddict(fsutils.StableDict) with the measure in
                  table[specN][segname]

    If specN has no segmentation of a given name, the corresponding measure
    is 0.0.
    """
    # create an ordered 2d table
    table = fsutils.Ddict(fsutils.StableDict)
    for spec, id_name_map, measures in disorganized_table:
        # position of every segmentation name in the measure list, built once
        # per entry; the first occurrence wins when a name is repeated
        position = {}
        for idx, name in enumerate(id_name_map.values()):
            position.setdefault(name, idx)

        for seg in segnamelist:
            if seg in position:
                table[spec][seg] = measures[position[seg]]
            else:
                table[spec][seg] = 0.0

    rows = [spec for (spec, _names, _measures) in disorganized_table]
    return rows, segnamelist, table

    
    
def sanitize_table(options, disorganized_table):
    """
    Merge per-row column lists into one consistent 2d table.

    Args:
        options: the parsed options; only all_flag and common_flag are read
        disorganized_table: a sequence of tuples
                (spec1, collist1, measurelist1)
                (spec2, collist2, measurelist2)
                ..
            specN        - usually the name of the subject (or the number of
                           the stat file)
            collistN     - list of col entries (segmentations)
            measurelistN - list of measures corresponding to collistN
                           (the lengths can differ between rows)

    The collists of different rows are usually not the same because
    different stats files contain slightly different segmentations.

    - If --all-segs is specified, output the union of the segs
      (0 is used as measure wherever a row lacks a seg)
    - If --common-segs is specified, output the intersection of the segs
    - Otherwise all rows must have the same segs; if not, an error is
      printed and the script exits with status 1.

    The script also exits with status 1 when the same row id occurs twice.

    Returns:
        rows, columns, table
        rows    - the list of specifiers (subjects / statsfile numbers)
        columns - the sequence of all output segmentation names
        table   - StableDict(StableDict) with the measures, first index is
                  the row, second the column; every row holds exactly the
                  entries of `columns`
    """
    l.debug('-'*40)
    l.debug('Sanitizing the table')

    # check if row id's are unique
    # future: check if rows are the same, if so, drop one of them, else exit
    #  difficulty: columns can be in different order, so they need to be
    #  reordered based on col headers
    seen_rows = set()
    for row, _cols, _data in disorganized_table:
        if row in seen_rows:
            print("ERROR: duplicate row found: "+row+", not sure if identical data, therefore stopping!")
            sys.exit(1)
        seen_rows.add(row)

    # future allow for placeholders (based on passed option selections)

    # now simply merge cols
    all_rows = [row for row, _cols, _data in disorganized_table]
    col_lists = [cols for _row, cols, _data in disorganized_table]

    if options.all_flag:
        # create union: flatten the list of lists, then drop duplicates
        flattened = [item for sublist in col_lists for item in sublist]
        all_cols = fsutils.unique_union(flattened)
    elif not col_lists:
        # nothing was parsed (e.g. every input was skipped): there is no
        # first row to take the columns from, so the table is simply empty
        all_cols = []
    elif options.common_flag:
        # create intersection
        all_cols = col_lists[0]
        for cols in col_lists[1:]:
            all_cols = fsutils.intersect_order(all_cols, cols)
    else:
        # check if all cols are the same (regardless of their order)
        firstcols = sorted(col_lists[0])
        for cols in col_lists[1:]:
            if not firstcols == sorted(cols):
                print('ERROR: All stat files should have the same segmentations')
                print('If one or more stats file have different segs from others,')
                print('use --common-segs or --all-segs flag depending on the need.')
                print('(see help)')
                sys.exit(1)
        all_cols = col_lists[0]

    # future: remove specific cols based on passed options

    # go through all rows, pad with zeros if fields are missing
    zero_data = fsutils.StableDict()
    for col in all_cols:
        zero_data[col] = 0

    all_data = fsutils.StableDict()
    for rowid, cols, data in disorganized_table:
        # init with zero
        row_data = zero_data.copy()
        # fill existing elements from the disorganized table
        for i, col in enumerate(cols):
            value = data[i]
            # only output columns are stored: with --common-segs a row may
            # carry segs outside the intersection, and storing them would
            # give the rows different sets of entries
            if col in zero_data:
                row_data[col] = value
        all_data[rowid] = row_data

    return all_rows, all_cols, all_data

        

def write_table(outfile, measure, rows, cols, table, delimiter, transposeflag):
    """
    Write the in-memory table to `outfile` through fsutils.TableWriter.

    The top-left cell of the output is 'Measure:<measure>'. `delimiter` is
    one of the keys of delimiter2char and selects the separator string.
    With `transposeflag` set the transposed table is written.
    """
    writer = fsutils.TableWriter(rows, cols, table)
    corner = 'Measure:%s' % measure
    separator = delimiter2char[delimiter]
    writer.assign_attributes(outfile, row1col1=corner, delimiter=separator)

    if not transposeflag:
        writer.write()
        return
    writer.write_transpose()


def parse_stats_table(intable):
    """
    Read one stats table file.

    The first line of the file is the header: its first whitespace separated
    field is the measure name, the remaining fields are the column names.
    Every following line is one row: the row name followed by its values.
    Lines containing a '#' anywhere and blank lines are ignored.

    Args:
        intable: path of the stats table file
    Returns:
        measure, cols, rows, data
        measure - first field of the header line
        cols    - list of the column names
        rows    - list of the row names
        data    - list (one entry per row) of lists of value strings;
                  the values are not converted to numbers
    Raises:
        BadFileError (not defined in this file; presumably provided by the
        `subject_info` star import) when the file does not exist or has an
        empty header line.
    """
    if not os.path.exists(intable):
        print('ERROR: '+str(intable)+' not found!')
        raise BadFileError(intable)

    rows = []
    data = []
    fp = open(intable, 'r')
    try:
        headers = fp.readline().split()
        if not headers:
            # empty file or blank header: too small to be a stats table
            raise BadFileError(intable)
        measure = headers[0]
        cols = headers[1:]
        for line in fp:
            if '#' in line:
                continue
            fields = line.split()
            if not fields:
                # blank line, e.g. at the end of the file
                continue
            rows.append(fields[0])
            data.append(fields[1:])
    finally:
        # release the file handle also when parsing fails
        fp.close()

    return measure, cols, rows, data

if __name__=="__main__":
    # Parse the command line; all option validation happens in there
    options = options_parse()
    l.debug('-- The options you entered --')
    l.debug(options)

    # Paths of the stats table files to merge
    subj_listofpaths = assemble_inputs(options)

    # The table in memory: a list of tuples
    #   (row name, column names, row values)
    # with one tuple per row of every parsed file
    # (note there is no segid list!)
    pretable = []

    # Parse the stats files
    print('Parsing the stats table files')
    for filepath in subj_listofpaths:
        l.debug('-'*20)
        l.debug('Processing file ' + filepath)

        try:
            # read in table (later also check segnos, no_segnos, maxsegnos etc;
            # the asegstats-style filtering is not enabled in this script)
            measure, cols, rows, data = parse_stats_table(filepath)
        except BadFileError:
            e = sys.exc_info()[1]
            if not options.skipflag:
                print('ERROR: The stats table '+str(e)+' is not found or is too small to be a valid stats table')
                print('Use --skip flag to automatically skip bad stats tables')
                sys.exit(1)
            print('Skipping ' + str(e))
            continue

        # each row gets a copy of the cols header
        for rowname, rowvalues in zip(rows, data):
            pretable.append((rowname, cols, rowvalues))

    # Make sure the table has the same number of cols for all stats files
    # and merge them up, clean them up etc. More in the documentation of the fn.
    print('Building the table..')
    rows, columns, table = sanitize_table(options, pretable)

    # Write this table ( in memory ) to disk.. function uses TableWriter class
    print('Writing the table to %s' % options.outputfile)
    write_table(options.outputfile, options.meas, rows, columns, table,
                options.delimiter, options.transposeflag)

    # always exit with 0 exit code
    sys.exit(0)
