#!/usr/bin/env python

# Copyright (c) 2008-2021 the MRtrix3 contributors.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#
# Covered Software is provided under this License on an "as is"
# basis, without warranty of any kind, either expressed, implied, or
# statutory, including, without limitation, warranties that the
# Covered Software is free of defects, merchantable, fit for a
# particular purpose or non-infringing.
# See the Mozilla Public License v. 2.0 for more details.
#
# For more details, see http://www.mrtrix.org/.



import json



def usage(cmdline): #pylint: disable=unused-variable
  # Describe this script to the command-line interface object provided by the caller:
  #   authorship, one-line synopsis, longer description, and the arguments / options accepted
  authors = [ 'Lena Dorfschmidt (ld548@cam.ac.uk)',
              'Jakub Vohryzek (jakub.vohryzek@queens.ox.ac.uk)',
              'Robert E. Smith (robert.smith@florey.edu.au)' ]
  cmdline.set_author(' and '.join(authors))

  cmdline.set_synopsis('Concatenating multiple DWI series accounting for differential intensity scaling')

  # The description consists of two sentences, separated by a single space
  description_sentences = [ 'This script concatenates two or more 4D DWI series, accounting for the fact '
                            'that there may be differences in intensity scaling between those series.',
                            'This intensity scaling is corrected by determining scaling factors that will make '
                            'the overall image intensities in the b=0 volumes of each series approximately equivalent.' ]
  cmdline.add_description(' '.join(description_sentences))

  # Positional arguments: one or more input series, followed by the output image
  cmdline.add_argument('inputs',
                       nargs='+',
                       help='Multiple input diffusion MRI series')
  cmdline.add_argument('output',
                       help='The output image series (all DWIs concatenated)')

  # Optional mask within which intensities are matched
  cmdline.add_argument('-mask',
                       metavar='image',
                       help='Provide a binary mask within which image intensities will be matched')



def execute(): #pylint: disable=unused-variable
  # Concatenate the input DWI series along the fourth axis, after rescaling every
  #   series other than the first so that its b=0 intensities match those of the first.
  # Reads its parameters from app.ARGS (inputs, output, mask) and raises MRtrixError
  #   if the inputs are unsuitable or a scaling factor cannot be obtained.
  from mrtrix3 import MRtrixError #pylint: disable=no-name-in-module, import-outside-toplevel
  from mrtrix3 import app, image, path, run #pylint: disable=no-name-in-module, import-outside-toplevel

  input_paths = app.ARGS.inputs
  mask_path = app.ARGS.mask
  num_inputs = len(input_paths)
  if num_inputs < 2:
    raise MRtrixError('Script requires at least two input image series')

  # Each input must be 4D, and must carry a gradient table
  #   with exactly one line per volume
  def validate_dwi_header(hdr):
    dims = hdr.size()
    if len(dims) != 4:
      raise MRtrixError('Image "' + hdr.name() + '" is not a 4D image series')
    keyval = hdr.keyval()
    if 'dw_scheme' not in keyval:
      raise MRtrixError('Image "' + hdr.name() + '" does not contain a gradient table')
    grad_line_count = len(keyval['dw_scheme'])
    volume_count = dims[3]
    if grad_line_count != volume_count:
      raise MRtrixError('Number of lines in gradient table for image "' + hdr.name() + '" (' + str(grad_line_count) + ') does not match number of volumes (' + str(volume_count) + ')')

  # The first input acts as the reference against which all others are compared
  reference_header = image.Header(path.from_user(input_paths[0], False))
  validate_dwi_header(reference_header)
  reference_dims = reference_header.size()[0:3]

  protocol_fields = ( 'EchoTime', 'RepetitionTime', 'FlipAngle' )
  protocol_differs = False
  for input_path in input_paths[1:]:
    candidate_header = image.Header(path.from_user(input_path, False))
    validate_dwi_header(candidate_header)
    if candidate_header.size()[0:3] != reference_dims:
      raise MRtrixError('Spatial dimensions of image "' + input_path + '" do not match those of first image "' + reference_header.name() + '"')
    for field in protocol_fields:
      reference_value = reference_header.keyval().get(field)
      candidate_value = candidate_header.keyval().get(field)
      # A field is only compared if both headers provide a (non-empty) value for it
      if not (reference_value and candidate_value):
        continue
      if reference_value != candidate_value:
        protocol_differs = True
  if protocol_differs:
    app.warn('Mismatched protocol acquisition parameters detected between input images; '
             'the assumption of equivalent intensities between b=0 volumes of different inputs underlying operation of this script may not be valid')

  # If a mask was provided, it must lie on the same voxel grid as the first input
  if mask_path:
    mask_header = image.Header(path.from_user(mask_path, False))
    if mask_header.size()[0:3] != reference_dims:
      raise MRtrixError('Spatial dimensions of mask image "' + mask_path + '" do not match those of first image "' + reference_header.name() + '"')

  # Make sure the output can be written before doing any processing
  app.check_output_path(path.from_user(app.ARGS.output, False))

  # Bring all data into the scratch directory, naming each input by its index,
  #   and then work from inside that directory
  app.make_scratch_dir()
  for index, input_path in enumerate(input_paths):
    run.command(' '.join([ 'mrconvert',
                           path.from_user(input_path),
                           path.to_scratch(str(index) + 'in.mif') ]))
  if mask_path:
    run.command(' '.join([ 'mrconvert',
                           path.from_user(mask_path),
                           path.to_scratch('mask.mif'),
                           '-datatype bit' ]))
  app.goto_scratch_dir()

  # Pull out the b=0 volumes of every input series
  for index in range(num_inputs):
    prefix = str(index)
    run.command('dwiextract ' + prefix + 'in.mif ' + prefix + 'b0.mif -bzero')

  # The same mask is applied to both the image being rescaled and the target image
  if mask_path:
    mask_option = ' -mask_input mask.mif -mask_target mask.mif'
  else:
    mask_option = ''

  # For every series other than the first:
  #   - estimate the multiplicative factor that matches its b=0 volumes to those of the first series
  #   - multiply the entire series by that factor
  # Ideally no single input would be given preferential treatment:
  #   - every input would be compared against every other input
  #   - a single scaling factor per image would then be derived from all of those results
  #       a plain geometric average is not suitable: e.g. with 2 images, each would be mapped
  #         to the input intensity of the other, and so the two corrected images would not match
  #       there should be some mathematical theorem that yields the optimal scaling factor
  #         for each image from the resulting matrix of pairwise scaling factors
  series_to_concatenate = [ '0in.mif' ]
  for index in range(1, num_inputs):
    prefix = str(index)
    histmatch_result = run.command('mrhistmatch scale ' + prefix + 'b0.mif 0b0.mif ' + prefix + 'rescaledb0.mif' + mask_option)
    # The factor is read from the text that mrhistmatch writes to stderr:
    #   it is the final token of the first line containing the phrase below
    report_lines = [ line for line in histmatch_result.stderr.splitlines() if 'Estimated scale factor is' in line ]
    if not report_lines:
      raise MRtrixError('Unable to extract scaling factor from mrhistmatch output')
    try:
      scaling_factor = float(report_lines[0].split()[-1])
    except ValueError:
      raise MRtrixError('Unable to convert scaling factor from mrhistmatch output to floating-point number')
    rescaled_name = prefix + 'rescaled.mif'
    run.command('mrcalc ' + prefix + 'in.mif ' + str(scaling_factor) + ' -mult ' + rescaled_name)
    series_to_concatenate.append(rescaled_name)

  # Join all series along the volume axis, exporting the header key-values to JSON
  run.command('mrcat ' + ' '.join(series_to_concatenate) + ' - -axis 3 | '
              'mrconvert - result.mif -json_export result_init.json -strides 0,0,0,1')

  # Discard command_history, as there is no sensible way to decide which input
  #   image it should be taken from; all other key-value entries
  #   (e.g. the gradient table) are retained
  with open('result_init.json', 'r') as json_in:
    header_keyval = json.load(json_in)
  if 'command_history' in header_keyval:
    del header_keyval['command_history']
  with open('result_final.json', 'w') as json_out:
    json.dump(header_keyval, json_out)

  # Write the final output, using the edited key-values
  run.command('mrconvert result.mif ' + path.from_user(app.ARGS.output),
              mrconvert_keyval='result_final.json',
              force=app.FORCE_OVERWRITE)




# Execute the script
# Control is handed to the mrtrix3 Python module, which is imported here at the
#   end of the file rather than alongside the other imports at the top.
# NOTE(review): usage() and execute() above are never called from within this file;
#   presumably mrtrix3.execute() locates and invokes them - confirm against the mrtrix3 library
import mrtrix3
mrtrix3.execute() #pylint: disable=no-member
