#!/usr/bin/env python
# +-======-+ 
#  Copyright (c) 2003-2018 United States Government as represented by 
#  the Admistrator of the National Aeronautics and Space Administration.  
#  All Rights Reserved.
#  
#  THIS OPEN  SOURCE  AGREEMENT  ("AGREEMENT") DEFINES  THE  RIGHTS  OF USE,
#  REPRODUCTION,  DISTRIBUTION,  MODIFICATION AND REDISTRIBUTION OF CERTAIN 
#  COMPUTER SOFTWARE ORIGINALLY RELEASED BY THE UNITED STATES GOVERNMENT AS 
#  REPRESENTED BY THE GOVERNMENT AGENCY LISTED BELOW ("GOVERNMENT AGENCY").  
#  THE UNITED STATES GOVERNMENT, AS REPRESENTED BY GOVERNMENT AGENCY, IS AN 
#  INTENDED  THIRD-PARTY  BENEFICIARY  OF  ALL  SUBSEQUENT DISTRIBUTIONS OR 
#  REDISTRIBUTIONS  OF THE  SUBJECT  SOFTWARE.  ANYONE WHO USES, REPRODUCES, 
#  DISTRIBUTES, MODIFIES  OR REDISTRIBUTES THE SUBJECT SOFTWARE, AS DEFINED 
#  HEREIN, OR ANY PART THEREOF,  IS,  BY THAT ACTION, ACCEPTING IN FULL THE 
#  RESPONSIBILITIES AND OBLIGATIONS CONTAINED IN THIS AGREEMENT.
#  
#  Government Agency: National Aeronautics and Space Administration
#  Government Agency Original Software Designation: GSC-15354-1
#  Government Agency Original Software Title:  GEOS-5 GCM Modeling Software
#  User Registration Requested.  Please Visit http://opensource.gsfc.nasa.gov
#  Government Agency Point of Contact for Original Software:  
#  			Dale Hithon, SRA Assistant, (301) 286-2691
#  
# +-======-+ 
"""Check obsys_rc file against available data."""

from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from glob     import glob
from obsys_rc import ObsysRc
from re       import compile
from sys      import stdout
from time     import strftime

# Placeholder datetimes.  zeros (yyyymmdd_hhnn) stands in for "no data
# found" in the (start, stop) datetime lists; zeros_ (yyyymmdd_hh) fills
# the commented-out table rows written for templates with no data.
# (A "global" statement at module level is a no-op, so none is used here.)
zeros  = "00000000_0000"
zeros_ = "00000000_00"

#.......................................................................
# Output separators and the "open-ended" stop datetime used by check().
_BORDER1 = "  #"+"-"*26+"\n"
_BORDER2 = "  #"+"="*26+"\n"
_INDEFINITE = "21001231_1800"

#.......................................................................
def check(filename="obsys.rc",
          newfile="default",
          errfile="default",
          obslist=("all",),
          ignore_gap=(),
          lastday=7):
    """
    Check obsys_rc file against available data.

    For each selected observation class, the data files named by each table
    template are located and two files are written:
    - newfile: the table rebuilt from the data actually found
    - errfile: each listed table row plus any discrepancies found (no data,
      incorrect interval, missing data, extra data, misplaced/mislabeled
      files)
    A summary of misplaced/mislabeled files is written to stdout.

    input parameters
    => filename: name of obsys_rc to check
    => newfile: name of output new file; defaults to filename+".new"
    => errfile: name of output error file; defaults to filename+".err"
    => obslist: list of observation classes to process ("all" = every class)
    => ignore_gap: list of obsclass[<threshold] values where data gaps less
                   than threshold (integer hours) will be ignored; if
                   obsclass eq "all", then threshold applies to all
                   obsclasses; default threshold = 24 (hours)
    => lastday: number of days before "today" to stop looking for data

    raises
    => ValueError: if newfile/errfile equals filename or an ignore_gap
                   threshold is not an integer; also propagated from the
                   data scan (unparseable filenames, inconsistent ranges)
    """
    global today   # also read by _get_data_info() and _write_data_report()

    # initializations
    #----------------
    _set_global_pattern_dictionary()
    today = strftime("%Y%m%d_2359")
    misfile_summary = []

    # check output file names
    #------------------------
    if newfile == "default":
        newfile = filename + ".new"
    if newfile == filename:
        msg = "Cannot write new info back to the same file: {}"
        raise ValueError(msg.format(newfile))

    if errfile == "default":
        errfile = filename + ".err"
    if errfile == filename:
        msg = "Cannot write error info back to the same file: {}"
        raise ValueError(msg.format(errfile))

    # set ignore_gap thresholds (hours)
    #----------------------------------
    # NOTE: the "<threshold" text must become a number here because
    #       _get_data_info() multiplies it by 60 to get minutes.
    threshold = {}
    for obs_str in ignore_gap:
        if "<" in obs_str:
            (obsclass, thresh) = obs_str.split("<", 1)
        else:
            (obsclass, thresh) = (obs_str, 24)
        try:
            threshold[obsclass] = int(thresh)
        except ValueError:
            msg = "Invalid ignore_gap threshold (hours) for {}: {}"
            raise ValueError(msg.format(obsclass, thresh))

    # set lastdate (yyyymmdd_hh) to look for data
    #--------------------------------------------
    if lastday > 0:
        lastday = -1 * lastday
    lastdate = incr_datetime(today, int(lastday)*24*60)[0:11]

    # load input obsys_rc file
    #-------------------------
    obsys = ObsysRc(filename)

    # "with" closes both output files even if an exception is raised below
    #----------------------------------------------------------------------
    with open(newfile, mode="w") as newfl, open(errfile, mode="w") as errfl:
        for (obsclass, recvals) in obsys.obsinfo():
            if "all" not in obslist and obsclass not in obslist:
                continue

            # gap threshold: obsclass value, else "all" value, else 1 hour
            #-------------------------------------------------------------
            thresh = threshold.get(obsclass, threshold.get("all", 1))

            print("\nChecking {}".format(obsclass))

            # output prolog lines, table start, and comments
            #-----------------------------------------------
            if recvals["outtmpl"] == "":
                begin = "BEGIN {}\n".format(obsclass)
            else:
                begin = "BEGIN {} => {}\n".format(obsclass, recvals["outtmpl"])

            header = [line+"\n" for line in (recvals["prolog"] or [])]
            header.append(begin)
            header.extend([line+"\n" for line in recvals["comments"]])
            newfl.writelines(header)
            errfl.writelines(header)

            # group obsclass rows by template, keeping first-seen order
            #----------------------------------------------------------
            # NOTE: an obsclass can have multiple templates; every row
            #       keeps its own listed interval
            templates = []
            rows_by_template = {}
            for row in recvals["rows"]:
                (start_stop, interval, template) = row.split()
                (start, stop) = pattern["start_stop"].search(start_stop).groups()
                if template not in rows_by_template:
                    templates.append(template)
                    rows_by_template[template] = []
                rows_by_template[template].append((start+"00", stop+"00", interval))

            # process and output rows by template
            #------------------------------------
            nodata = []
            for template in templates:
                print("=> {}".format(template))
                rows = rows_by_template[template]
                echo = ["  %sz-%sz %s %s\n" % (start[0:11], stop[0:11], ivl, template)
                        for (start, stop, ivl) in rows]

                # the discrepancy report follows the first listed row
                errfl.write(echo[0])

                (start_stop_data, deltaMin, misfiles) = \
                    _get_data_info(template, thresh)
                misfile_summary.extend(misfiles)

                if start_stop_data[0][0] == zeros:
                    errfl.write(_BORDER2+"  # NO DATA FOUND!!!\n"+_BORDER1+"\n")
                    nodata.append(template)
                else:
                    # the found interval is compared against this
                    # template's most recently listed interval
                    start_stop_list = [(start, stop) for (start, stop, _) in rows]
                    _write_data_report(newfl, errfl, template, start_stop_list,
                                       rows[-1][2], start_stop_data, deltaMin,
                                       misfiles, lastdate)

                # echo the remaining listed rows for this template
                errfl.writelines(echo[1:])

            # output table end
            #-----------------
            for template in nodata:
                newfl.write("# %sz-%sz %s %s\n" % (zeros_, zeros_, "000000", template))
            newfl.write("END\n")
            newfl.flush()

            errfl.write("END\n")
            errfl.flush()

    # write misfile summary to stdout
    #--------------------------------
    if misfile_summary:
        stdout.write("\n# MISLABELED and MISPLACED files\n"
                     "#-------------------------------\n")
        for mis in sorted(misfile_summary):
            stdout.write(mis+"\n")

#.......................................................................
def _write_data_report(newfl, errfl, template, start_stop_list, interval,
                       start_stop_data, deltaMin, misfiles, lastdate):
    """
    Write rebuilt table rows for one template to newfl and its interval,
    missing-data, extra-data and misfile discrepancies to errfl.

    input parameters
    => newfl, errfl: open output files
    => template: data path/name template
    => start_stop_list: (start, stop) yyyymmdd_hhnn datetimes listed in table
    => interval: hhnnss interval listed in table
    => start_stop_data: (start, stop) yyyymmdd_hhnn datetimes of data found
    => deltaMin: data frequency found (minutes)
    => misfiles: misplaced/mislabeled file descriptions (sorted in place)
    => lastdate: yyyymmdd_hh datetime; missing-data gaps starting at or
                 after it are not reported

    raises
    => ValueError: if the first datetime to check is after the final one
    """
    # write actual data availability info to newfile
    #-----------------------------------------------
    (hh, nn) = divmod(deltaMin, 60)
    interval_found = "%02d%02d00" % (hh, nn)

    for (start, stop) in start_stop_data:
        start_ = start[0:11]
        stop_ = stop[0:11]
        if start_ < lastdate and stop >= lastdate:
            # range reaches lastdate: list it as open-ended
            stop_ = _INDEFINITE[0:11]
        line = "  %sz-%sz %s %s\n" % (start_, stop_, interval_found, template)
        if start == zeros:
            line = "#" + line[1:]
        newfl.write(line)

    # write interval discrepancy info to check output
    #------------------------------------------------
    if interval != interval_found:
        msg  = "  # INCORRECT INTERVAL\n"+_BORDER1
        msg += "  # interval listed: {}\n".format(interval)
        msg += "  # interval found:  {}\n".format(interval_found)
        errfl.write(_BORDER2+msg+_BORDER1+"\n")

    # span to check: listed rows (indefinite stop -> lastdate) and data,
    # never beyond today; min/max so that row order does not matter
    #-------------------------------------------------------------------
    listed_first = min(start for (start, _) in start_stop_list)
    listed_final = max(stop for (_, stop) in start_stop_list)
    if listed_final == _INDEFINITE:
        listed_final = lastdate

    first = min(listed_first, start_stop_data[0][0])
    final = min(max(listed_final, start_stop_data[-1][1]), today)

    if final < first:
        msg = "first datetime is greater than final: {} > {}"
        raise ValueError(msg.format(first, final))

    # look for data discrepancies
    #----------------------------
    miss_data = []
    found_data = []
    when = first
    while when <= final:
        listed = _included(start_stop_list, when)
        found = _included(start_stop_data, when)
        if listed and not found:
            miss_data.append(when)       # data listed but not found
        elif found and not listed:
            found_data.append(when)      # data found but not listed
        when = incr_datetime(when, deltaMin)

    # write miss data info to check output
    #-------------------------------------
    if miss_data:
        msg  = "  # MISSING DATA\n"
        msg += "  # gap start < %sz\n" % (lastdate)

        wrote_info = False
        for (start, stop) in _start_stop_tuples(miss_data, deltaMin):
            if start >= lastdate:
                continue
            if not wrote_info:
                errfl.write(_BORDER2+msg+_BORDER1)
                wrote_info = True
            if start == stop:
                errfl.write("  %sz\n" % (start[0:11]))
            else:
                errfl.write("  %sz-%sz\n" % (start[0:11], stop[0:11]))

        if wrote_info:
            errfl.write(_BORDER1+"\n")

    # write extra found data info to check output
    #--------------------------------------------
    if found_data:
        errfl.write(_BORDER2+"  # MORE DATA FOUND\n"+_BORDER1)
        for (start, stop) in _start_stop_tuples(found_data, deltaMin):
            if start == stop:
                errfl.write("  %sz\n" % (start[0:11]))
            else:
                errfl.write("  %sz-%sz\n" % (start[0:11], stop[0:11]))
        errfl.write(_BORDER1+"\n")

    # write misfile info to check output
    #-----------------------------------
    if misfiles:
        errfl.write(_BORDER2+"  # MISFILES\n"+_BORDER1)
        misfiles.sort()
        for mis in misfiles:
            errfl.write("  #"+mis+"\n")
        errfl.write(_BORDER1+"\n")

#.......................................................................
def _csplit(strval, char=","):
    """Return strval split on char (a comma by default); used as argparse type."""
    pieces = strval.split(char)
    return pieces

#.......................................................................
def _get_data_info(template, thresh):
    """
    Return list of (start, stop) datetime tuples, plus delta value (in minutes)
    and list of misplaced and misnamed data.

    input parameters
    => template: data path/name template; text up to and including the
                 first ":" is not part of the path
    => thresh: data gap threshold (hours); ignore data gaps < thresh;
               a number or a numeric string

    return values
    => start_stop_data: (start, stop) yyyymmdd_hhnn tuples covering the data
                        found; [(zeros, zeros)] if no usable data was found
    => deltaMin: data frequency (minutes): one day divided by the number of
                 distinct hour+minute values found; 24*60 if no data
    => misfiles: "(MISPLACED) ...", "(MISLABELED) ..." and
                 "(FUTURE_DATE?) ..." file descriptions

    raises
    => ValueError: if no date/time can be extracted from a filename, if a
                   /yyyy/jjj/ directory holds a file without a day-of-year
                   in its name, or if the distinct hour+minute values found
                   do not evenly divide a day
    """
    # get list of data filepaths; template tokens become glob wildcards
    #------------------------------------------------------------------
    tmpl = template[template.find(":")+1:]
    for (token, wildcard) in (("%y4", "????"), ("%y2", "??"), ("%m2", "??"),
                              ("%d2", "??"), ("%h2", "??"), ("%n2", "??"),
                              ("%j3", "???"), ("%c", "?")):
        tmpl = tmpl.replace(token, wildcard)

    # extract datetimes from available data
    #--------------------------------------
    times_found = set()     # distinct hhnn values, used to infer frequency
    datetime_list = []
    misfiles = []

    for fpath in glob(tmpl):
        (year, month, day, jjj, hour, minute) = _parse_filename_datetime(fpath)

        # check for misplaced files (directory date disagrees with filename)
        #--------------------------------------------------------------------
        jjj_dir = pattern["YYYY_JJJ"].search(fpath)
        if jjj_dir:
            (YYYY, JJJ) = jjj_dir.groups()

            if jjj is None:
                msg = "JJJ (day-of_year) found in path but not in filename: {}"
                raise ValueError(msg.format(fpath))

            if year != YYYY or jjj != JJJ:
                misfiles.append("(MISPLACED) "+fpath)
                continue

        # day-of-year filenames: derive month/day for every such file, not
        # only those under /yyyy/jjj/ (month would otherwise stay None);
        # an impossible day-of-year marks the file as mislabeled
        #--------------------------------------------------------------------
        if jjj is not None and month is None:
            try:
                (month, day) = jjj2mmdd(year, jjj)
            except ValueError:
                misfiles.append("(MISLABELED) "+fpath)
                continue

        if not jjj_dir:
            ym_dir = pattern["Y4_M2"].search(fpath)
            if ym_dir and (year, month) != ym_dir.groups():
                misfiles.append("(MISPLACED) "+fpath)
                continue

        # check for mislabeled files (impossible date/time values)
        #---------------------------------------------------------
        # NOTE(review): hour 24 is still accepted as before; confirm whether
        #               any datasets really use 24z
        if (not 1 <= int(month) <= 12
                or not 1 <= int(day) <= num_days_in_month(int(year), int(month))
                or int(hour) > 24
                or int(minute) > 59):
            misfiles.append("(MISLABELED) "+fpath)
            continue

        datetime = year+month+day+"_"+hour+minute
        if datetime > today:
            misfiles.append("(FUTURE_DATE?) "+fpath)
            continue

        # add file date/time to list
        #---------------------------
        datetime_list.append(datetime)
        times_found.add(hour+minute)

    # determine deltaMin
    #-------------------
    if times_found:
        num_times = len(times_found)
        if (24*60) % num_times:
            msg = "Non-divisible number of hour+min times found: {} for {}"
            raise ValueError(msg.format(num_times, template))
        deltaMin = (24*60) // num_times
    else:
        datetime_list = [zeros]
        deltaMin = 24*60

    # gaps shorter than thresh hours do not split a (start, stop) range;
    # float() accepts a numeric string (e.g. from --ignore_gap) too
    #-------------------------------------------------------------------
    gapMin = max(deltaMin, int(float(thresh)*60))
    start_stop_data = _start_stop_tuples(datetime_list, gapMin)

    return (start_stop_data, deltaMin, misfiles)

#.......................................................................
def _search_groups(name, fpath):
    """Return groups of the first match of pattern[name] in fpath, or None."""
    match = pattern[name].search(fpath)
    return match.groups() if match else None

#.......................................................................
def _parse_filename_datetime(fpath):
    """
    Extract date/time fields from a data filepath.

    Filename patterns are tried in a fixed order; the first match wins.

    return value
    => (year, month, day, jjj, hour, minute) strings; month and day are None
       for day-of-year filenames, jjj is None otherwise, and minute is "00"
       when the matching pattern has no minutes

    raises
    => ValueError: if no pattern matches fpath
    """
    groups = _search_groups("Ayyyyjjj_hhnn", fpath)
    if groups:
        (year, jjj, hour, minute) = groups
        return (year, None, None, jjj, hour, minute)

    groups = _search_groups("yyyymmddhh", fpath)
    if groups:
        (year, month, day, hour) = groups
        return (year, month, day, None, hour, "00")

    groups = _search_groups("yyyymmdd_hhnnz", fpath)
    if groups:
        (year, month, day, hour, minute) = groups
        return (year, month, day, None, hour, minute)

    for name in ("yyyymmdd_hh", "yyyymmdd__hhz"):
        groups = _search_groups(name, fpath)
        if groups:
            (year, month, day, hour) = groups
            return (year, month, day, None, hour, "00")

    groups = _search_groups("yymmdd__hhz", fpath)
    if groups:
        (yy, month, day, hour) = groups
        year = ("19" if int(yy) > 60 else "20") + yy   # 2-digit year pivot
        return (year, month, day, None, hour, "00")

    groups = _search_groups("yyyymmdd", fpath)
    if groups:
        (year, month, day) = groups
        return (year, month, day, None, "00", "00")

    groups = _search_groups("yyyyjjj_hhnn", fpath)
    if groups:
        (year, jjj, hour, minute) = groups
        return (year, None, None, jjj, hour, minute)

    # date from /Yyyyy/Mmm/Ddd/ directories plus an hhz in the path;
    # reuse the first match instead of re-searching from position 1
    ymd_dir = _search_groups("Y4_M2_D2", fpath)
    hhz = _search_groups("hhz", fpath)
    if ymd_dir and hhz:
        (year, month, day) = ymd_dir
        return (year, month, day, None, hhz[0], "00")

    msg = "Cannot extract date/time from filename in fpath: {}"
    raise ValueError(msg.format(fpath))

#.......................................................................
def _start_stop_tuples(datetime_list, deltaMin):
    """
    Collapse datetimes into contiguous (start, stop) ranges.

    The datetimes are taken in ascending order; a new range begins whenever
    a datetime is later than the previous one plus deltaMin minutes.

    input parameters
    => datetime_list: yyyymmdd_hhnn datetime strings (the list is not
                      modified)
    => deltaMin: largest step (minutes) still considered contiguous

    return value
    => list of (start, stop) tuples in ascending order; [] for empty input
    """
    ordered = sorted(datetime_list)
    if not ordered:
        return []

    ranges = []
    start = previous = ordered[0]
    for current in ordered[1:]:
        if current > incr_datetime(previous, deltaMin):
            ranges.append((start, previous))
            start = current
        previous = current
    ranges.append((start, previous))

    return ranges

#.......................................................................
def _included(date_ranges, datetime):
    """Return True if datetime lies within any inclusive (start, stop) range."""
    return any(start <= datetime <= stop for (start, stop) in date_ranges)

#.......................................................................
def _set_global_pattern_dictionary():
    """
    Compile the regular expressions used to find dates/times and store them
    in the module-level dict `pattern`, keyed by name.

    Covers obsys_rc table rows (start_stop), date/time fields embedded in
    data filenames, and date directories in data paths.
    """
    global pattern

    regex_table = (
        # obsys_rc table row: yyyymmdd_hhz-yyyymmdd_hhz
        ("start_stop",     r"(\d{8}_\d{2})z-(\d{8}_\d{2})z"),

        # date/time fields embedded in data filenames
        ("Ayyyyjjj_hhnn",  r"A(\d{4})(\d{3})\D(\d{2})(\d{2})\D"),
        ("yyyymmddhh",     r"\D(\d{4})(\d{2})(\d{2})(\d{2})"),
        ("yyyymmdd_hhnnz", r"\D(\d{4})(\d{2})(\d{2})\D(\d{2})(\d{2})z"),
        ("yyyymmdd_hhz",   r"\D(\d{4})(\d{2})(\d{2})\D(\d{2})z"),
        ("yyyymmdd_hh",    r"\D(\d{4})(\d{2})(\d{2})\D(\d{2})"),
        ("yyyymmdd__hhz",  r"\D(\d{4})(\d{2})(\d{2})\D+(\d{2})z"),
        ("yymmdd__hhz",    r"\D(\d{2})(\d{2})(\d{2})\D+(\d{2})z"),
        ("yyyymmdd",       r"\D(\d{4})(\d{2})(\d{2})"),
        ("yyyyjjj_hhnn",   r"\D(\d{4})(\d{3})\D(\d{2})(\d{2})"),
        ("hhz",            r"\D(\d{2})z"),

        # date directories in data paths
        ("Y4_M2_D2",       r"/Y(\d{4})/M(\d{2})/D(\d{2})/"),
        ("Y4_M2",          r"/Y(\d{4})/M(\d{2})/"),
        ("YYYY_JJJ",       r"/(\d{4})/(\d{3})/"),
    )

    pattern = dict((name, compile(regex)) for (name, regex) in regex_table)

#.......................................................................
def incr_datetime(datetime, deltaMin):
    """
    Shift a date/time string by a number of minutes.

    Calendar rollover (month lengths, leap years, year boundaries) is done
    one day at a time using num_days_in_month(); apart from the integer
    conversions, the input is not validated.

    input parameters
    => datetime: date/time in yyyymmdd_hhnn format
    => deltaMin: number of minutes to add (negative values subtract)

    return value
    => shifted date/time in yyyymmdd_hhnn format

    raises
    => ValueError: if the date or hour/minute fields are not integers
    """
    yr = int(datetime[0:4])
    mon = int(datetime[4:6])
    dd = int(datetime[6:8])
    try:
        hh = int(datetime[9:11])
        nn = int(datetime[11:13])
    except ValueError:
        msg = "EXCEPTION: datetime = {}"
        raise ValueError(msg.format(datetime))

    # normalize minutes into 0..59, carrying whole hours
    (carry, nn) = divmod(nn + deltaMin, 60)
    hh += carry

    # normalize hours into 0..23, carrying whole days
    (days, hh) = divmod(hh, 24)

    # step forward one day at a time
    while days > 0:
        days -= 1
        dd += 1
        if dd > num_days_in_month(yr, mon):
            dd = 1
            mon += 1
            if mon > 12:
                mon = 1
                yr += 1

    # step backward one day at a time
    while days < 0:
        days += 1
        dd -= 1
        if dd < 1:
            mon -= 1
            if mon < 1:
                mon = 12
                yr -= 1
            dd = num_days_in_month(yr, mon)

    return "%04d%02d%02d_%02d%02d" % (yr, mon, dd, hh, nn)

#.......................................................................
def num_days_in_month(year, month):
    """
    Return the number of days in the given month (1-12) of year.

    February has 29 days in Gregorian leap years: divisible by 4, except
    century years not divisible by 400.
    """
    if month == 2 and year % 4 == 0 and (year % 100 != 0 or year % 400 == 0):
        return 29
    days_per_month = (0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31)
    return days_per_month[month]

#.......................................................................
def jjj2mmdd(year, jjj):
    """
    Convert a day-of-year into zero-padded (month, day) strings.

    input parameters
    => year: year (int or numeric string), used for leap-year handling
    => jjj: day-of-year (int or numeric string)

    return value
    => (mm, dd) two-digit strings

    raises
    => ValueError: if jjj exceeds the number of days in year
    """
    remaining = int(jjj)
    month = 1
    while True:
        month_len = num_days_in_month(int(year), month)
        if remaining <= month_len:
            break
        if month == 12:
            msg = "Day-of-year value is too large: year = {}, jjj = {}"
            raise ValueError(msg.format(year, jjj))
        remaining -= month_len
        month += 1

    return ("%02d" % month, "%02d" % remaining)

#.......................................................................
if __name__ == "__main__":
    # Command-line entry point: check obsys_rc file against available data.

    # get calling parameters
    #-----------------------
    parser = ArgumentParser(description=__doc__,
                            formatter_class=ArgumentDefaultsHelpFormatter)
    #=========
    # filename
    #=========
    parser.add_argument("filename",
                        nargs="?",
                        type=str,
                        default="obsys.rc",
                        help="name of obsys_rc file")
    #========
    # newfile
    #========
    parser.add_argument("newfile",
                        nargs="?",
                        type=str,
                        default='filename+".new"',
                        help="name of output new file")
    #========
    # errfile
    #========
    parser.add_argument("errfile",
                        nargs="?",
                        type=str,
                        default='filename+".err"',
                        help="name of output error file")
    #==========
    # --obslist
    #==========
    # NOTE: argparse applies type=_csplit to the string default, so the
    #       default becomes ["all"]
    parser.add_argument("--obslist",
                        metavar="obsclass_list",
                        type=_csplit,
                        default="all",
                        help="""list of observation classes to process,
                                separated by commas, no spaces""")
    #=============
    # --ignore_gap
    #=============
    parser.add_argument("--ignore_gap",
                        metavar="obsclass_threshold_list",
                        type=_csplit,
                        default=[],
                        help="""list of obsclass[<threshold] values where data
                                gaps less than threshold (hours) will be ignored;
                                multiple values separated by commas, no spaces;
                                use obsclass = "all" to apply threshold to all
                                obsclasses; default threshold = 24 (hours)""")
    #==========
    # --lastday
    #==========
    # NOTE: no nargs="?" here; a bare "--lastday" would otherwise yield None
    #       and make check() fail on lastday, so argparse now reports it
    parser.add_argument("--lastday",
                        type=int,
                        default=7,
                        help="""number of days before "today" to stop looking for data""")

    # extract calling parameters
    #---------------------------
    args = parser.parse_args()

    filename = args.filename

    newfile = args.newfile
    if newfile == 'filename+".new"':
        newfile = filename+".new"

    errfile = args.errfile
    if errfile == 'filename+".err"':
        errfile = filename+".err"

    # call check function
    #--------------------
    check(filename, newfile, errfile, args.obslist, args.ignore_gap, args.lastday)
