#!/usr/bin/python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#  _______ __   _ ____     __
# |       |  \ | |___ \   / /
# |       |   \| | __) | / /-,_
# |       | |\   |/ __/ /__   _|
# |_______|_| \__|_____|   |_|
#
# @author Markus Birth <markus.birth@weltn24.de>

import datetime
import email.utils
import html
import re
import sys
import time
import urllib.error
import urllib.parse
import urllib.request
from argparse import ArgumentParser

import pytz

class DuplicateKeyException(Exception):
    """Signals an attempt to store a value under a key that is already defined."""
    pass

class LogCollector():
    """Collects info messages, errors with severities and graph values.

    Also tracks the highest severity seen so far and can emit verbose
    debug lines on STDERR.
    """

    def __init__(self):
        self.verbose = False
        self.worst_sev = 0
        self.info_list = []
        self.error_list = []
        self.graph_order = []
        self.graph_data = {}

    def set_verbose(self, new_status):
        """Switch verbose STDERR output on or off."""
        self.verbose = new_status

    def printv(self, message, lvl=1):
        """Print *message* to STDERR when verbose mode is enabled.

        The line is prefixed with the name of the function found *lvl*
        frames above this method (1 = the direct caller).
        """
        if self.verbose:
            caller_name = sys._getframe(lvl).f_code.co_name
            print(caller_name + ": " + message, file=sys.stderr)

    def add_info(self, message):
        """Store an informational message."""
        self.info_list.append(message)
        # lvl=2: attribute the verbose line to the caller of add_info()
        self.printv("INFO - " + message, 2)

    def add_error(self, severity, message):
        """Store an error and raise the worst severity if *severity* exceeds it."""
        self.error_list.append([severity, message])
        # lvl=2: attribute the verbose line to the caller of add_error()
        self.printv("ERROR({:d}) - {}".format(severity, message), 2)
        if severity <= self.worst_sev:
            return
        self.worst_sev = severity
        self.printv(">>> Increased worst severity to: {:d}".format(severity))

    def add_graph(self, id, value):
        """Register a graph value; raises DuplicateKeyException for a known id."""
        if id in self.graph_order:
            raise DuplicateKeyException()
        self.graph_data[id] = value
        self.graph_order.append(id)

    def get_worst_status(self):
        """Return the highest severity passed to add_error() so far (0 if none)."""
        return self.worst_sev

    def get_worst_errors(self):
        """Return the messages of all errors at the worst severity, space-joined."""
        worst = [msg for sev, msg in self.error_list if sev >= self.worst_sev]
        return " ".join(worst)

    def get_infos(self):
        """Return the list of collected info messages."""
        return self.info_list

    def get_perfdata(self):
        """Return the graph values as `` |key=value key=value`` or an empty string.

        Spaces in keys are replaced by underscores; insertion order is kept.
        """
        if not self.graph_order:
            return ""
        pairs = [
            "{}={}".format(key.replace(" ", "_"), str(self.graph_data[key]))
            for key in self.graph_order
        ]
        return " |" + " ".join(pairs)

class UrlLoader():
    """Performs HTTP requests through urllib and reports status and timing."""

    def __init__(self):
        pass

    def do_request(self, req):
        """Open *req* and measure the time spent inside ``urlopen()``.

        Args:
            req: A ``urllib.request.Request`` (or URL string) to open.

        Returns:
            A ``(status, payload, duration_ms)`` tuple:

            * success: ``(res.status, res, elapsed milliseconds)`` where
              ``res`` is the open response object;
            * ``HTTPError``: ``(err.code, reason, -1)``; for HTTP 404 a
              ``Reference #...`` id found in the error body is appended to
              the reason;
            * any other ``URLError``: ``(-1, err.reason, -1)``.

            A duration of ``-1`` therefore always marks a failed request.
        """
        try:
            # Monotonic clock: immune to wall-clock adjustments during the request.
            start = time.monotonic()
            res = urllib.request.urlopen(req)
            elapsed = time.monotonic() - start
        except urllib.error.HTTPError as err:
            reason = err.reason
            if err.code == 404:
                # Error pages are not guaranteed to be valid UTF-8.
                body = err.read().decode("utf-8", errors="replace")
                ref = re.search("Reference&#32;&#35;(.*)$", body, re.MULTILINE)
                if ref:
                    ref_plain = html.unescape(ref.group(1))
                    reason = "{}, ref#{}".format(reason, ref_plain)
            return err.code, reason, -1
        except urllib.error.URLError as err:
            return -1, err.reason, -1

        # Full elapsed time in ms (timedelta.microseconds would drop whole seconds).
        return res.status, res, round(elapsed * 1000)

    def get_headers(self, url):
        """Request *url* and return only the response headers.

        Returns:
            ``(status, headers, duration_ms)`` on success. If the request
            failed, the tuple from ``do_request()`` is passed through
            unchanged, i.e. the second element is the failure reason.
        """
        #req = urllib.request.Request(url=url, method="HEAD")
        req = urllib.request.Request(url=url)   # work around HTTP 405 "Method not allowed"
        status, res, diff = self.do_request(req)
        if diff < 0:
            # Failed request: res is the reason, not a response object.
            return status, res, diff
        headers = res.info()
        # Only the headers are needed; release the connection right away.
        res.close()
        return status, headers, diff

    def get_content(self, url):
        """Request *url* and return the ``do_request()`` result unchanged."""
        req = urllib.request.Request(url=url)
        return self.do_request(req)


###############################################################################


class CheckHls():
    STATUS_OK   = 0
    STATUS_WARN = 1
    STATUS_CRIT = 2
    STATUS_UNKN = 3

    def __init__(self):
        """Run the complete check and terminate the process.

        Parses the command line, checks the master playlist (and through it
        all stream playlists), compares the found bandwidths with the
        expected ones, prints the status line with perfdata and info
        messages, and exits with the worst severity as exit code.
        """
        self.ec = LogCollector()
        self.ul = UrlLoader()

        # Parse commandline and set verbose mode
        self.getOpts()
        self.ec.set_verbose(self.opts["verbose"])

        # Normalise expected bitrates
        exp_rates = map(self.bw_plain, self.opts["exp_rates"])

        # Parse playlists and return found bitrates array (i.e. start the main checks)
        found_rates = self.check_master_playlist(self.opts["url"])
        # Sort numerically, then go back to strings for comparison and output
        found_rates = [str(rate) for rate in sorted(map(int, found_rates))]
        self.ec.printv("Found rates: {}".format(repr(found_rates)))

        missing_rates = set(exp_rates).difference(found_rates)
        # Sort numerically as well (a plain string sort would put "1000000" before "600000")
        missing_rates = [str(rate) for rate in sorted(map(int, missing_rates))]
        missing_nice = list(map(self.bw_nice, missing_rates))
        if len(missing_rates) > 0:
            self.ec.add_error(self.STATUS_WARN, "Bandwidth(s) missing from master playlist: {}".format(", ".join(missing_nice)))
        self.ec.printv("Missing rates: {}".format(repr(missing_rates)))

        self.ec.printv("Info msgs: {}".format(repr(self.ec.get_infos())))
        self.ec.printv("Err  msgs: {}".format(repr(self.ec.error_list)))

        status_text = self.ec.get_worst_errors()
        if self.ec.get_worst_status() == self.STATUS_OK:
            found_nice = list(map(self.bw_nice, found_rates))
            status_text = "All ok. Bandwidths: {}".format(", ".join(found_nice))

        # The two-character sequence backslash + n is printed literally on purpose.
        print("{}{}\\n{}".format(status_text, self.ec.get_perfdata(), "\\n".join(self.ec.get_infos())))
        sys.exit(self.ec.get_worst_status())

    def getOpts(self):
        """Parse the command line and store the options as a dict in ``self.opts``."""
        cli = ArgumentParser(description="Check HLS stream for expected bitrates.")
        cli.add_argument(
            "-U", "--url",
            dest="url",
            metavar="URL",
            required=True,
            help="URL to master playlist (m3u8)",
        )
        cli.add_argument(
            "-b",
            dest="exp_rates",
            metavar="BITRATE",
            action="append",
            required=True,
            help="Bitrate(s) to check (can be specified multiple times), e.g. 600k or 2M",
        )
        cli.add_argument(
            "-I",
            dest="ignore_master_date",
            action="store_true",
            required=False,
            help="Ignore date/age of master playlist.",
        )
        cli.add_argument(
            "--verbose",
            action="store_true",
            help="Verbose logging to STDERR",
        )

        # Keep the options as a plain dict keyed by destination name
        self.opts = vars(cli.parse_args())

    def check_age(self, age, subject):
        """Record the age of *subject* and flag it when it is too old.

        The age is always stored as info message. More than 60 seconds is
        reported as critical, more than 10 seconds as warning.
        """
        age_text = str(age)
        self.ec.printv("{} age: {}".format(subject, age_text))
        report = "{} is {} seconds old.".format(subject, age_text)
        self.ec.add_info(report)

        if age > 60:
            severity = self.STATUS_CRIT
        elif age > 10:
            severity = self.STATUS_WARN
        else:
            return
        self.ec.add_error(severity, report)

    def check_date(self, now, date_rfc2822, subject):
        """Convert an RFC 2822 date header into an age and pass it to check_age().

        Example input: ``Fri, 09 Dec 2016 18:20:05 GMT``. The age is the
        difference between *now* and the parsed timestamp in seconds.
        """
        remote_stamp = self.parse_rfc2822(date_rfc2822)
        self.ec.printv("{} remote : {}".format(subject, str(remote_stamp)))

        # Same instant shown in CET for the verbose log
        local_stamp = remote_stamp.astimezone(pytz.timezone("CET"))
        self.ec.printv("{} locally: {}".format(subject, str(local_stamp)))

        age_seconds = (now - remote_stamp).total_seconds()
        self.check_age(age_seconds, subject)

    def check_url(self, status, res, url, subject=""):
        """Validate the response of a playlist request.

        Checks the HTTP status, records the Content-Length as graph value,
        checks the Date / Last-Modified / Age headers (unless this is the
        master playlist and its date is to be ignored) and finally reads the
        first line of the body, which must be ``#EXTM3U``.

        Args:
            status: HTTP status code of the request.
            res: Response object; on a non-200 status only used as text in
                the error message.
            url: Requested URL, used in messages.
            subject: Human readable name; its first word prefixes the size
                graph id.

        Returns:
            True if the status is 200 and the body starts with ``#EXTM3U``,
            otherwise False. Note that one line is consumed from *res*.
        """
        self.ec.printv("### Now checking: {}".format(url))
        if status != 200:
            self.ec.add_error(self.STATUS_CRIT, "HTTP {} ({}) for {}.".format(status, res, url))
            return False

        # Look headers up through the message object: HTTP header names are
        # case-insensitive, a plain dict lookup would miss e.g. "content-length".
        headers = res.info()

        # Only used for age differences, so the concrete timezone is irrelevant.
        now = datetime.datetime.now(tz=datetime.timezone.utc)
        self.ec.printv("Headers: " + repr(dict(headers.items())))

        content_length = headers.get("Content-Length")
        if content_length is not None:
            # An empty subject must not break the graph id
            graph_prefix = (subject.split() or [""])[0]
            self.ec.add_graph("{}_size".format(graph_prefix), int(content_length))

        if not (subject == "Master playlist" and self.opts["ignore_master_date"]):
            # Check Date: header (should always be present)
            date_header = headers.get("Date")
            if date_header is None:
                self.ec.add_error(self.STATUS_WARN, "{}: Date header missing.".format(subject))
            else:
                self.check_date(now, date_header, "{}: Date header".format(subject))

            # Check Last-Modified: header if present
            last_modified = headers.get("Last-Modified")
            if last_modified is not None:
                self.check_date(now, last_modified, "{}: Last-Modified header".format(subject))

            # Check Age: header if present
            age_header = headers.get("Age")
            if age_header is not None:
                self.check_age(int(age_header), "{}: Age header".format(subject))

        # Check header line
        header = res.readline().decode("utf-8").strip()
        if header != "#EXTM3U":
            self.ec.add_error(self.STATUS_CRIT, "EXTM3U header not found in {}".format(url))
            return False
        return True

    def parse_metaline(self, line, metadata = None):
        """Split a playlist tag line into its name and its value(s).

        Args:
            line: Tag line, with or without the leading ``#``.
            metadata: Optional dict that receives the parsed ``KEY=VALUE``
                attributes (it is updated in place); a new dict is used
                when omitted.

        Returns:
            ``(name, value)``. ``value`` is the raw payload string when the
            payload holds no ``KEY=VALUE`` attribute (e.g. ``EXTINF:8,`` or a
            tag without payload, which yields ``""``), otherwise the
            *metadata* dict. Quotes around attribute values are kept.
        """
        if metadata is None:
            metadata = {}

        # startswith() also copes with an empty line
        if line.startswith("#"):
            line = line[1:]

        # Split at the FIRST colon only: payloads such as timestamps or URIs
        # contain colons themselves.
        metaname, _, payload = line.partition(":")

        # Re-join values that contained commas (e.g. CODECS="a, b"): a part
        # without "=" belongs to the attribute before it.
        pairs = []
        for part in payload.split(","):
            if "=" in part or not pairs:
                pairs.append(part)
            else:
                pairs[-1] += "," + part

        if len(pairs) == 1 and "=" not in pairs[0]:
            return metaname, pairs[0]

        # Extract meta data; split at the first "=" only, values may contain "="
        for pair in pairs:
            if "=" not in pair:
                continue
            key, _, value = pair.partition("=")
            metadata[key] = value

        return metaname, metadata

    def parse_rfc2822(self, date_str):
        """Parse an RFC 2822 date string into a timezone-aware datetime.

        Example: ``Fri, 09 Dec 2016 18:20:05 GMT``. Parsing is done by the
        standard library's e-mail date parser, so it does not depend on the
        process locale and also understands numeric offsets like ``+0100``.

        Args:
            date_str: Date string as found in HTTP date headers.

        Returns:
            An aware ``datetime``. If the parser yields a naive value for
            the given zone, the timestamp is taken as UTC.

        Raises:
            ValueError: If the string cannot be parsed (older Python
                versions may raise TypeError instead - TODO confirm for the
                targeted interpreter).
        """
        stamp = email.utils.parsedate_to_datetime(date_str)
        if stamp.tzinfo is None:
            # Keep the result aware so it can be subtracted from an aware "now"
            stamp = stamp.replace(tzinfo=datetime.timezone.utc)
        return stamp

    def bw_plain(self, bw):
        """Normalise a bitrate such as ``600k``, ``2M`` or ``1.5Mbit`` to plain digits.

        Args:
            bw: Bitrate string; a number followed by ``k``/``K`` (x1000) or
                ``M`` (x1000000) and an optional ``bit`` suffix.

        Returns:
            The bitrate in bit/s as a string of digits, e.g. ``"600000"``.
            Input without a recognised unit only gets the legacy textual
            replacement, so plain numbers are returned unchanged.
        """
        match = re.fullmatch(r"\s*(\d+(?:\.\d+)?)\s*([kKM])(?:bit)?\s*", bw)
        if match:
            number, unit = match.groups()
            multiplier = 1000000 if unit == "M" else 1000
            if "." in number:
                # Fractional values ("1.5M") cannot be handled by appending zeros
                return str(round(float(number) * multiplier))
            return str(int(number) * multiplier)

        # Legacy behaviour for everything else (e.g. "600000" or "600000bit")
        bw = bw.replace("M", "000000")
        bw = bw.replace("k", "000")
        bw = bw.replace("bit", "")
        return bw

    def bw_nice(self, bw):
        """Format a plain bitrate for display, e.g. ``"600000"`` -> ``"600kbit"``.

        Whole multiples of 1000000 are shown with ``M``, other whole
        multiples of 1000 with ``k``; everything else is shown unabbreviated.
        Only the magnitude of the whole number is considered, so zeros in
        the middle of a value (``1000500``) are left alone.

        Args:
            bw: Bitrate in bit/s as a string of digits. Other strings are
                passed through unchanged.

        Returns:
            The display string with ``bit`` appended.
        """
        if bw.isdecimal():
            value = int(bw)
            if value and value % 1000000 == 0:
                bw = "{}M".format(value // 1000000)
            elif value and value % 1000 == 0:
                bw = "{}k".format(value // 1000)
        return bw + "bit"

    def check_master_playlist(self, url):
        """Fetch and validate the master playlist and check every stream in it.

        Each ``#EXT-X-STREAM-INF`` tag is combined with the URI line that
        follows it; the referenced stream playlist is then checked through
        ``check_stream_playlist()``.

        Args:
            url: URL of the master playlist.

        Returns:
            List of the distinct BANDWIDTH values (strings) found; empty if
            the master playlist itself failed its checks.
        """
        self.ec.printv("Master playlist URL: " + url)
        # Get Master playlist
        found_rates = []

        status, res, dur = self.ul.get_content(url)
        if not self.check_url(status, res, url, "Master playlist"):
            return []

        self.ec.add_info("Master playlist: {}ms for {} ".format(dur, url))
        self.ec.add_graph("Master_resp", dur)

        # #EXTM3U
        # #EXT-X-STREAM-INF:PROGRAM-ID=193181205,BANDWIDTH=200000
        # live/stream3.m3u8
        # #EXT-X-STREAM-INF:PROGRAM-ID=193181205,BANDWIDTH=600000
        # live/stream2.m3u8
        # #EXT-X-STREAM-INF:PROGRAM-ID=193181205,BANDWIDTH=1000000
        # live/stream1.m3u8
        # #EXT-X-STREAM-INF:PROGRAM-ID=1,BANDWIDTH=328000,RESOLUTION=416x234,CODECS="avc1.77.30, mp4a.40.2"

        metadata = {}
        try:
            for raw_line in iter(res.readline, b""):
                line = raw_line.decode("utf-8").strip()
                if not line:
                    # Blank lines carry no information
                    continue
                if line.startswith("#EXT-X-STREAM-INF:"):
                    ignore, metadata = self.parse_metaline(line, metadata)
                    continue
                if line.startswith("#"):
                    continue

                # URI line: resolve it against the playlist URL (absolute
                # URLs are returned unchanged by urljoin)
                suburl = urllib.parse.urljoin(url, line)
                if ".m3u8" not in suburl:
                    self.ec.add_error(self.STATUS_WARN, "Expected .m3u8 in master playlist, found: {}".format(line))
                elif "BANDWIDTH" not in metadata:
                    self.ec.add_error(self.STATUS_WARN, "No BANDWIDTH attribute found for {} in master playlist.".format(line))
                else:
                    found_rates.append(metadata['BANDWIDTH'])
                    self.check_stream_playlist(metadata, suburl)
                # The attributes belong to this URI only; never carry them over
                metadata = {}
        finally:
            res.close()

        return list(set(found_rates))

    def check_stream_playlist(self, metadata, url):
        """Fetch and validate one stream playlist and probe its newest segments.

        Collects all segment URIs with their ``#EXTINF`` durations, requests
        the last two segments through ``check_stream_part()`` and records
        the ratio of segment play time to request time as graph value.

        Args:
            metadata: Attributes of the stream from the master playlist;
                ``BANDWIDTH`` is used to name the stream.
            url: URL of the stream playlist.
        """
        self.ec.printv("Stream URL: " + url)
        self.ec.printv("Stream info: " + repr(metadata))
        self.ec.printv("Stream bw: " + self.bw_nice(metadata['BANDWIDTH']))
        stream_name = self.bw_nice(metadata['BANDWIDTH'])

        # Akamai hack (Primary streams have -p.m3u8, secondary have -b.m3u8)
        if "-b.m3u8" in url:
            stream_name = "B" + stream_name

        status, res, dur = self.ul.get_content(url)
        if not self.check_url(status, res, url, "{} playlist".format(stream_name)):
            return

        self.ec.add_info("{} playlist: {}ms for {} ".format(stream_name, dur, url))
        self.ec.add_graph("{}_resp".format(stream_name), dur)

        # #EXT-X-MEDIA-SEQUENCE:211755
        # #EXT-X-ALLOW-CACHE:NO
        # #EXT-X-VERSION:2
        # #EXT-X-TARGETDURATION:8
        # #EXTINF:8,
        # ../../../../hls-live/streams/eventlive/events/_definst_/live/stream1Num211755.ts

        self.ec.printv("Now parsing playlist...")

        extinf_tag = "#EXTINF:"
        segment_urls = []
        total_duration = 0.0
        pending_extinf = None
        try:
            for raw_line in iter(res.readline, b""):
                line = raw_line.decode("utf-8").strip()
                if not line:
                    # Blank lines carry no information
                    continue
                if line.startswith(extinf_tag):
                    #EXTINF:10.0,     = actual duration of next segment (float + comma [+ title])
                    pending_extinf = line[len(extinf_tag):].split(",", 1)[0].strip()
                    continue
                if line.startswith("#"):
                    # No other tag is evaluated here
                    continue

                # URI line: resolve it against the playlist URL (absolute
                # URLs are returned unchanged by urljoin)
                suburl = urllib.parse.urljoin(url, line)
                if ".m3u8" in suburl:
                    self.ec.add_error(self.STATUS_WARN, "Expected no further playlist in {}, found: {}".format(url, line))
                else:
                    try:
                        seconds = float(pending_extinf)
                    except (TypeError, ValueError):
                        # No usable EXTINF in front of this segment
                        self.ec.add_error(self.STATUS_WARN, "Missing or invalid EXTINF for segment {} in {}".format(line, url))
                        seconds = 0.0
                    segment_urls.append([suburl, seconds])
                    total_duration += seconds/60
                # An EXTINF applies to the next URI line only
                pending_extinf = None
        finally:
            res.close()

        self.ec.printv("Found {} segments ({} minutes playtime) in playlist.".format(len(segment_urls), total_duration))

        part_count   = 0
        part_count_success = 0
        part_seg_len = 0.0
        part_dl_ms   = 0.0
        # Only check last (newest) 2 segments
        for segment in segment_urls[-2:]:
            part_dl_this = self.check_stream_part(segment[0], "{} playlist".format(stream_name))
            part_count += 1
            if part_dl_this > 0:
                part_seg_len += segment[1]
                part_dl_ms += part_dl_this
                part_count_success += 1

        if part_dl_ms > 0:
            realtime_factor = part_seg_len * 1000 / part_dl_ms
        else:
            # Nothing could be measured
            realtime_factor = -10
        warn_icon = ""
        if part_count_success == 0:
            warn_icon = "(!!)"
        elif part_count_success < part_count:
            warn_icon = "(!)"
        self.ec.add_info("{} playlist: Downloaded {:d}{} of {:d} segments ({:.2f} seconds) in {:.0f} ms. ({:.2f}x)".format(stream_name, part_count_success, warn_icon, part_count, part_seg_len, part_dl_ms, realtime_factor))
        self.ec.add_graph("{}_factor".format(stream_name), realtime_factor)

    def check_stream_part(self, url, subject=""):
        """Request one segment and return the measured duration.

        Returns the duration reported by ``get_headers()`` for an HTTP 200
        answer. Any other status is recorded as warning and yields ``-1``.
        """
        self.ec.printv("TS URL: " + url)
        status, msg, dur = self.ul.get_headers(url)
        if status == 200:
            return dur

        problem = "HTTP {} ({}) for segment {} from {}.".format(status, msg, url, subject)
        self.ec.add_error(self.STATUS_WARN, problem)
        return -1

# Script entry point: constructing CheckHls runs the whole check and ends
# the process through sys.exit() with the worst severity as exit code.
if __name__ == "__main__":
    CheckHls()
