#!/usr/bin/env python

# Script for 'repairing' a FreeSurfer parcellation image
# FreeSurfer's sub-cortical structure segmentation has been observed to be highly variable
#   under scan-rescan conditions. This introduces unwanted variability into the connectome,
#   as the parcellations don't overlap with the sub-cortical segmentations provided by
#   FIRST for the sake of Anatomically-Constrained Tractography. This script determines the
#   node indices that correspond to these structures, and replaces them with estimates
#   derived from FIRST.

# Make the corresponding MRtrix3 Python libraries available
import inspect, math, os, sys
# The libraries are expected in a 'lib' directory that is a sibling of the directory
#   holding this script (<script dir>/../lib); realpath() resolves any symbolic links
#   so that the search is relative to the actual location of the script file.
lib_folder = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))), os.pardir, 'lib'))
if not os.path.isdir(lib_folder):
  # The mrtrix3 'app' module is not importable at this point, so report directly to stderr;
  #   terminate the message with a newline so it does not run into the shell prompt
  sys.stderr.write('Unable to locate MRtrix3 Python libraries\n')
  sys.exit(1)
# Insert at the front of the module search path so this directory is searched first
sys.path.insert(0, lib_folder)



from mrtrix3 import app, fsl, image, path, run



# Register the script with the MRtrix3 application framework: author and synopsis,
#   then citations, positional arguments and script-specific options, and finally
#   parse the command line.
app.init('Robert E. Smith (robert.smith@florey.edu.au)',
         'In a FreeSurfer parcellation image, replace the sub-cortical grey matter structure delineations using FSL FIRST')

# Each entry is (reference text, flag); both are forwarded unchanged to addCitation(),
#   together with an empty string as the first argument
citation_table = [
  ('Patenaude, B.; Smith, S. M.; Kennedy, D. N. & Jenkinson, M. A Bayesian model of shape and appearance for subcortical brain segmentation. NeuroImage, 2011, 56, 907-922', True),
  ('Smith, S. M.; Jenkinson, M.; Woolrich, M. W.; Beckmann, C. F.; Behrens, T. E.; Johansen-Berg, H.; Bannister, P. R.; De Luca, M.; Drobnjak, I.; Flitney, D. E.; Niazy, R. K.; Saunders, J.; Vickers, J.; Zhang, Y.; De Stefano, N.; Brady, J. M. & Matthews, P. M. Advances in functional and structural MR image analysis and implementation as FSL. NeuroImage, 2004, 23, S208-S219', True),
  ('Smith, R. E.; Tournier, J.-D.; Calamante, F. & Connelly, A. The effects of SIFT on the reproducibility and biological accuracy of the structural connectome. NeuroImage, 2015, 104, 253-265', False),
]
for citation_text, citation_flag in citation_table:
  app.cmdline.addCitation('', citation_text, citation_flag)

# Positional arguments, in the order in which they must appear on the command line
positional_table = [
  ('parc', 'The input FreeSurfer parcellation image'),
  ('t1', 'The T1 image to be provided to FIRST'),
  ('lut', 'The lookup table file that the parcellated image is based on'),
  ('output', 'The output parcellation image'),
]
for arg_name, arg_help in positional_table:
  app.cmdline.add_argument(arg_name, help=arg_help)

# Boolean switches specific to this script; both are off unless given
options = app.cmdline.add_argument_group('Options for the labelsgmfix script')
switch_table = [
  ('-premasked', 'Indicate that brain masking has been applied to the T1 input image'),
  ('-sgm_amyg_hipp', 'Consider the amygdalae and hippocampi as sub-cortical grey matter structures, and also replace their estimates with those from FIRST'),
]
for switch_name, switch_help in switch_table:
  options.add_argument(switch_name, action='store_true', default=False, help=switch_help)

app.parse()

# Validate the user-supplied paths and the execution environment before any processing

# The output path and the T1 image are checked first
output_path_to_check = path.fromUser(app.args.output, False)
app.checkOutputPath(output_path_to_check)
t1_path_to_check = path.fromUser(app.args.t1, False)
image.check3DNonunity(t1_path_to_check)

# FSL is a hard dependency of this script
if app.isWindows():
  app.error('Script cannot run on Windows due to FSL dependency')

# FSLDIR is used below as the root directory under which the FIRST model data are sought
fsl_path = os.getenv('FSLDIR', '')
if not fsl_path:
  app.error('Environment variable FSLDIR is not set; please run appropriate FSL configuration script')

# Name of the FIRST driver executable, as resolved by the mrtrix3 fsl module
first_cmd = fsl.exeName('run_first_all')

# The FIRST models must be present inside the FSL installation
first_atlas_path = os.path.join(fsl_path, 'data', 'first', 'models_336_bin')
if not os.path.isdir(first_atlas_path):
  app.error('Atlases required for FSL\'s FIRST program not installed;\nPlease install fsl-first-data using your relevant package manager')

# Build the mapping from FIRST structure names (keys, e.g. 'L_Caud') to the
#   corresponding FreeSurfer node names (values, e.g. 'Left-Caudate').
# By default only the 5 bilateral structures used in ACT are handled, since FreeSurfer's
#   hippocampus / amygdala segmentations look good enough; the -sgm_amyg_hipp option
#   adds those two structures as well.
# Each entry: (FIRST abbreviation, FreeSurfer structure name), both without the side
sgm_structures = [ ('Accu', 'Accumbens-area'),
                   ('Caud', 'Caudate'),
                   ('Pall', 'Pallidum'),
                   ('Puta', 'Putamen'),
                   ('Thal', 'Thalamus-Proper') ]
if app.args.sgm_amyg_hipp:
  sgm_structures.extend([ ('Amyg', 'Amygdala'),
                          ('Hipp', 'Hippocampus') ])

# Expand every structure into its left and right hemisphere entries
structure_map = { }
for first_abbrev, freesurfer_structure in sgm_structures:
  for first_side, freesurfer_side in ( ('L', 'Left'), ('R', 'Right') ):
    structure_map[first_side + '_' + first_abbrev] = freesurfer_side + '-' + freesurfer_structure

# Decide whether the T1 must be resampled before being handed to FIRST.
# The geometric mean of the three voxel dimensions is used as the representative voxel size;
#   the 1.225mm threshold sits just below 1.25mm, so images of 1.25mm or coarser are flagged,
#   on the guess that the user has erroneously re-gridded their data.
t1_spacing = image.Header(path.fromUser(app.args.t1, False)).spacing()
t1_voxel_volume = t1_spacing[0] * t1_spacing[1] * t1_spacing[2]
upsample_for_first = math.pow(t1_voxel_volume, 1.0/3.0) > 1.225
if upsample_for_first:
  app.warn('Voxel size of input T1 image larger than expected for T1-weighted images (' + str(t1_spacing) + '); image will be resampled to 1mm isotropic in order to maximise chance of FSL FIRST script succeeding')

app.makeTempDir()

# Import the parcellation and the T1 into the temporary directory.
# The T1 is written as NIfTI with fixed strides, the format handed to FSL below.
fsl_strides_option = ' -strides -1,+2,+3'

run.command('mrconvert ' + path.fromUser(app.args.parc, True) + ' ' + path.toTemp('parc.mif', True))

t1_source = path.fromUser(app.args.t1, True)
t1_target = path.toTemp('T1.nii', True)
if upsample_for_first:
  # Resample to 1mm with sinc interpolation, then take the voxel-wise maximum with 0.0
  #   (presumably to remove negative values produced by the sinc interpolation)
  t1_import_command = 'mrresize ' + t1_source + ' - -voxel 1.0 -interp sinc | mrcalc - 0.0 -max - | mrconvert - ' + t1_target + fsl_strides_option
else:
  t1_import_command = 'mrconvert ' + t1_source + ' ' + t1_target + fsl_strides_option
run.command(t1_import_command)

# All subsequent commands refer to files by name relative to the temporary directory
app.gotoTempDir()

# Segment the requested structures with FIRST.
# The option list is assembled piecewise; '-b' is included only if the user has
#   declared (via -premasked) that the T1 has already been brain-masked.
first_options = [ '-m none',
                  '-s ' + ','.join(structure_map.keys()),
                  '-i T1.nii' ]
if app.args.premasked:
  first_options.append('-b')
first_options.append('-o first')
run.command(first_cmd + ' ' + ' '.join(first_options))
# Check the outputs of FIRST for every requested structure
fsl.checkFirst('first', structure_map.keys())

# Create the image into which the new SGM nodes will be written: the voxel-wise
#   minimum of the parcellation and 0, i.e. an all-zero image on the parcellation grid
#   (assuming the parcellation contains no negative values)
run.command('mrcalc parc.mif 0 -min sgm.mif')

# Load the lookup table shipped with MRtrix3 that describes the SGM structures.
# The result maps a structure name (second column) to its index (first column).
sgm_lut_file_name = 'FreeSurferSGM.txt'
sgm_lut_file_path = os.path.join(path.sharedDataPath(), path.scriptSubDirName(), sgm_lut_file_name)
sgm_lut = {}
with open(sgm_lut_file_path) as lut_file:
  for raw_line in lut_file:
    entry = raw_line.rstrip()
    # Skip blank lines and comment lines
    if not entry or entry[0] == '#':
      continue
    fields = entry.split()
    # The index is kept as a string, since it is only ever pasted into a command line
    sgm_lut[fields[1]] = fields[0]

# Turn each FIRST surface mesh into a binary node mask on the parcellation grid, and
#   paint that mask into the SGM label image using the structure's index from the SGM LUT.
# The partial volume images are not wanted here; the mesh2voxel output is instead
#   thresholded at 0.5.
mask_list = [ ]
progress = app.progressBar('Generating mask images for SGM structures', len(structure_map))
for first_name, freesurfer_name in structure_map.items():
  mesh_prefix = 'first-' + first_name
  image_path = first_name + '_mask.mif'
  mask_list.append(image_path)

  # Transform the mesh vertices from FIRST's convention to real space, then voxelise
  transform_command = 'meshconvert ' + mesh_prefix + '_first.vtk ' + mesh_prefix + '_transformed.vtk -transform first2real T1.nii'
  run.command(transform_command)
  voxelise_command = 'mesh2voxel ' + mesh_prefix + '_transformed.vtk parc.mif - | mrthreshold - ' + image_path + ' -abs 0.5'
  run.command(voxelise_command)

  # Write the node index wherever the mask is set, keeping the existing SGM image
  #   elsewhere; overlap between structures is dealt with later
  node_index = sgm_lut[freesurfer_name]
  paint_command = 'mrcalc ' + image_path + ' ' + node_index + ' sgm.mif -if sgm_new.mif'
  run.command(paint_command)

  # Swap the updated image into place, unless the script is being run with the continue option
  if not app.continueOption:
    run.function(os.remove, 'sgm.mif')
    run.function(os.rename, 'sgm_new.mif', 'sgm.mif')
  progress.increment()
progress.done()

# Voxels claimed by more than one SGM mask are ambiguous: find them by summing all
#   masks and testing for a sum greater than 1, then zero them in the SGM label image
overlap_command = 'mrmath ' + ' '.join(mask_list) + ' sum - | mrcalc - 1 -gt sgm_overlap_mask.mif'
run.command(overlap_command)
run.command('mrcalc sgm_overlap_mask.mif 0 sgm.mif -if sgm_masked.mif')

# Translate the SGM labels from the indices of the MRtrix3 SGM LUT to the indices
#   defined by the user-provided LUT file
# NOTE(review): sgm_lut_file_path is pasted into the command string as-is, unlike the
#   user LUT path which goes through path.fromUser(); if the MRtrix3 installation path
#   can contain whitespace this may need quoting - confirm against run.command()
user_lut_path = path.fromUser(app.args.lut, True)
run.command('labelconvert sgm_masked.mif ' + sgm_lut_file_path + ' ' + user_lut_path + ' sgm_new_labels.mif')

# Remove the existing delineation of every SGM structure from the parcellation:
# * Determine which index the structure received after label conversion, taken as the
#   median of the converted label image within that structure's mask
# * Set every parcellation voxel carrying that index to zero
# The new delineations are inserted afterwards, once all structures have been stripped
progress = app.progressBar('Replacing SGM parcellations', len(structure_map))
for first_name in structure_map:
  image_path = first_name + '_mask.mif'
  mapped_index = image.statistic('sgm_new_labels.mif', 'median', '-mask ' + image_path)
  strip_command = 'mrcalc parc.mif ' + mapped_index + ' -eq 0 parc.mif -if parc_removed.mif'
  run.command(strip_command)
  # Replace the working parcellation with the stripped version
  run.function(os.remove, 'parc.mif')
  run.function(os.rename, 'parc_removed.mif', 'parc.mif')
  progress.increment()
progress.done()

# Merge in one step: wherever the new SGM label image is non-zero (> 0.5) take its value,
#   otherwise keep the stripped parcellation; the result is stored as unsigned integer
run.command('mrcalc sgm_new_labels.mif 0.5 -gt sgm_new_labels.mif parc.mif -if result.mif -datatype uint32')

# Export to the location requested by the user, overwriting only if this was requested
output_target = path.fromUser(app.args.output, True)
force_option = ' -force' if app.forceOverwrite else ''
run.command('mrconvert result.mif ' + output_target + force_option)
app.complete()
