#!/usr/bin/env python3

import argparse
import csv
import itertools
import json
import logging as log
import re
import subprocess
from collections import defaultdict
from pathlib import Path

from packaging import version

# Command line interface. Parsing happens at module level, so every function
# below reads its settings from the module-global ``args`` namespace.
arg_parser = argparse.ArgumentParser(description='Evaluates the ocr in ocrd project-style folder')
arg_parser.add_argument(
    "function",
    type=str,
    choices=["evaluate", "compare", "summarize", "visualize"],
    help="You can select between evaluate, summarize and visualize or do all in one run",
)
# Zero or more folders; each value is converted to a pathlib.Path.
arg_parser.add_argument("fdir", nargs='*', type=Path, help="Folderdirectory")
arg_parser.add_argument("-n", "--normalize", action='store_true', help="Normalizes the ocr and gt output")
arg_parser.add_argument("-c", "--combine", action='store_true',
                        help="Combines normalized and raw ocr and gt output")
arg_parser.add_argument("-t", "--filetype", type=str, default="TXT", help="Filetype")
# NOTE(review): --open is parsed but not referenced by the code visible in this file.
arg_parser.add_argument("-o", "--open", type=int, default=0,
                        help="Opens the first x pdfs in firefox (if installed)")
arg_parser.add_argument("-s", "--show", action='store_true', help="Shows plots")
arg_parser.add_argument("-v", "--verbose", action='store_true',
                        help="Shows information, like glyphe subsitutions")

args = arg_parser.parse_args()

# Infix used in report file names to tell normalized runs from raw runs.
if args.normalize:
    args.norminfo = ".normdata"
else:
    args.norminfo = ".rawdata"

# Logging: DEBUG level only when --verbose was given, library default otherwise.
if not args.verbose:
    log.basicConfig(format="%(levelname)s: %(message)s")
else:
    log.basicConfig(format="%(levelname)s: %(message)s", level=log.DEBUG)
    log.info("Verbose output.")


def normalize(fpath: Path, normfolder=None):
    """Write a normalized copy of a GT or OCR text file and return its path.

    Every line is run through a fixed sequence of regex substitutions (applied
    in order): long s becomes "s", the literal markers "/f" and "/n" are
    removed, whitespace directly around a "/" is dropped, and a/u/o followed
    by a combining small e become the corresponding umlaut. Lines that are
    empty after stripping whitespace are left out of the result.

    Files are read and written as UTF-8, independent of the platform locale,
    because the substitutions target non-ASCII characters.

    Args:
        fpath: Text file to normalize. If its name matches ``*.gt*`` the
            result is stored in the ``GT`` subfolder, otherwise in ``OCR``.
        normfolder: Folder below which ``EVAL/NORM/<GT|OCR>/<file name>`` is
            created. When omitted, the grandparent directory of ``fpath``
            is used.

    Returns:
        Path of the normalized file that was written.
    """
    # Order matters: each substitution works on the result of the previous one.
    substitutions = (
        (r"ſ", "s"),
        (r"/f", ""),
        (r"/n", ""),
        (r"(\s[/]\s)|(\s[/])|([/]\s)", "/"),
        (r"aͤ", "ä"),
        (r"uͤ", "ü"),
        (r"oͤ", "ö"),
    )
    kept_lines = []
    with fpath.open(encoding="utf-8") as fin:
        for textline in fin:
            for pattern, replacement in substitutions:
                textline = re.sub(pattern, replacement, textline)
            if textline.strip() != "":
                kept_lines.append(textline)

    subfolder = "GT" if fpath.match("*.gt*") else "OCR"
    base = Path(normfolder) if normfolder else fpath.parent.parent
    outpath = base.joinpath(f"EVAL/NORM/{subfolder}/{fpath.name}")
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    outpath.parent.mkdir(parents=True, exist_ok=True)
    with outpath.open("w", encoding="utf-8") as fout:
        fout.write("".join(kept_lines))
    return outpath


def compare():
    """Run dinglehopper for every ground-truth/OCR file pair below ``args.fdir``.

    For each directory in ``args.fdir`` all files matching ``*/GT/*.gt<ext>``
    are collected (``<ext>`` is ``.txt`` for filetype ``TXT``, else ``.xml``);
    files inside an ``EVAL`` folder are skipped. The grandparent of a GT file
    is treated as the book folder. Every file below the book folder whose name
    contains the GT file's base name and ends in ``<ext>`` is taken as an OCR
    candidate, except GT files themselves and anything below ``EVAL``.

    With ``args.normalize`` set, both files are normalized first and the
    normalized copies are compared. Reports are written with the prefix
    ``<book>/EVAL/reports/single/<ocr base name><norminfo>.report``.

    A tool that cannot be started or exits with a non-zero status is logged
    and the loop continues with the next pair.
    """
    extension = ".txt" if args.filetype == "TXT" else ".xml"
    for fdir in args.fdir:
        gtfiles = list(fdir.rglob(f"*/GT/*.gt{extension}"))
        log.info(gtfiles)
        for gtfile in gtfiles:
            if "EVAL" in gtfile.parts:
                continue
            mainfolder = gtfile.parent.parent
            if args.normalize:
                gtfile = normalize(gtfile, normfolder=mainfolder)
            basename = gtfile.name.split('.')[0]
            # Materialize the candidates before normalizing: normalize() creates
            # new files below mainfolder that a lazy glob could pick up.
            # Path.match() treats "**" like "*", so the EVAL folder is excluded
            # by inspecting the path components relative to the book folder.
            ocrfiles = [fpath for fpath in mainfolder.glob(f"**/*{basename}*{extension}")
                        if not fpath.match("*.gt.*")
                        and "EVAL" not in fpath.relative_to(mainfolder).parts]
            for ocrfile in ocrfiles:
                if args.normalize:
                    ocrfile = normalize(ocrfile, normfolder=mainfolder)
                reportpath = mainfolder.joinpath("EVAL/reports/single/").joinpath(
                    ocrfile.name.split(".", 1)[0] + args.norminfo + '.report')
                reportpath.parent.mkdir(parents=True, exist_ok=True)
                # Argument lists (no shell) keep paths containing spaces or
                # shell metacharacters intact.
                if args.filetype == "TXT":
                    call = ["dinglehopper", str(gtfile), str(ocrfile), str(reportpath)]
                else:
                    # Both file groups go into one comma-separated value; the
                    # former "GT, OCR" string was split into two arguments.
                    call = ["ocrd-dinglehopper",
                            "-m", str(mainfolder.joinpath("mets.xml")),
                            "-I", f"{gtfile.parts[-2]},{ocrfile.parts[-2]}",
                            "-O", str(reportpath)]
                log.info(call)
                try:
                    returncode = subprocess.call(call)
                except OSError as exc:
                    log.error("Could not run %s: %s", call[0], exc)
                    continue
                if returncode != 0:
                    log.warning("%s exited with status %s for %s", call[0], returncode, ocrfile)


def summarize_json():
    """Merge the single dinglehopper JSON reports into one summary per folder.

    For each directory in ``args.fdir`` all files matching
    ``*<norminfo>.report.json`` are grouped by the folder they are stored in.
    Per folder, the reports are collected into a nested mapping
    ``{report["gt"]: {report["ocr"]: report}}`` and written to
    ``<parent of report folder>/<book><norminfo>.sumreport.json``, where
    ``<book>`` is the name of the directory that contains the ``EVAL`` folder.

    Report folders that are not located below an ``EVAL`` directory are
    skipped with a warning.
    """
    for fdir in args.fdir:
        reports_by_folder = defaultdict(list)
        for rpath in fdir.rglob(f"*{args.norminfo}.report.json"):
            log.info(rpath)
            reports_by_folder[rpath.parent].append(rpath)
        for folder, reports in reports_by_folder.items():
            # Use the absolute path so that a relative folder starting with
            # "EVAL" does not wrap around to index -1 (the last component).
            parts = folder.absolute().parts
            if "EVAL" not in parts:
                log.warning("Skipping %s: not located below an EVAL folder", folder)
                continue
            mainfolder = parts[parts.index("EVAL") - 1]
            sum_report = defaultdict(dict)
            for report in reports:
                with report.open(encoding="utf-8") as fin:
                    data = json.load(fin)
                sum_report[data["gt"]][data["ocr"]] = data
            log.info(sum_report)
            sumpath = folder.parent.joinpath(f"{mainfolder}{args.norminfo}.sumreport.json")
            with sumpath.open('w', encoding="utf-8") as fout:
                json.dump(sum_report, fout, indent=4)


def summarize_csv():
    """Flatten the single dinglehopper JSON reports into a table per directory.

    For every directory in ``args.fdir`` all files matching
    ``*<norminfo>.report.json`` are read. Reports whose ``"ocr"`` path has more
    than two ``/``-separated components contribute one row with the columns
    Book, File, Fileformat, Normalized, Model, OCR, CER and WER (CER and WER
    rounded to three decimals). The table is written twice into the directory:
    tab-separated as ``All_Dinglehopper_Report<norminfo>.sumreport.csv`` and
    column-wise as ``All_Dinglehopper_Report<norminfo>.sumreport.json``.

    Every directory in ``args.fdir`` is processed; previously the function
    returned after the first one.
    """
    keys = ["Book", "File", "Fileformat", "Normalized", "Model", "OCR", "CER", "WER"]
    for fdir in args.fdir:
        res = defaultdict(list)
        for report in fdir.rglob(f"*{args.norminfo}.report.json"):
            with open(report, "r", encoding="utf-8") as fin:
                data = json.load(fin)
            ocrinfo = data["ocr"].split("/")
            if len(ocrinfo) <= 2:
                continue
            ocrname = ocrinfo[-1]
            res["Book"].append(ocrinfo[-2])
            res["File"].append("_".join(ocrname.split("_")[:2]))
            # NOTE(review): file format and normalization state are hard-coded
            # and do not follow --filetype / --normalize - confirm intent.
            res["Fileformat"].append("TXT")
            res["Normalized"].append("NORM")
            # Model: everything after the second "_" of the base name, with
            # further "_" turned into "/" and the marker "ABBYY" removed.
            res["Model"].append("/".join(ocrname.split(".")[0].split("_", 3)[2:]).replace("ABBYY", ""))
            res["OCR"].append("ABBYY" if "ABBYY" in data["ocr"] else "Tesseract")
            res["CER"].append(round(data["cer"], 3))
            res["WER"].append(round(data["wer"], 3))
        csvpath = fdir.joinpath(f'All_Dinglehopper_Report{args.norminfo}.sumreport.csv')
        with open(csvpath, 'w', newline='', encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile, delimiter="\t")
            writer.writerow(keys)
            writer.writerows(zip(*[res[key] for key in keys]))
        jsonpath = fdir.joinpath(f'All_Dinglehopper_Report{args.norminfo}.sumreport.json')
        log.info(jsonpath)
        with open(jsonpath, 'w', newline='', encoding="utf-8") as fout:
            json.dump(res, fout, indent=4)


def plot(x: list, error: list, factors: list, opath: Path, title: str):
    """ Creates the barchart plot with bokeh"""
    try:
        import bokeh
        from bokeh.io import show, save, output_file
        from bokeh.models import ColumnDataSource, FactorRange, Legend, BoxSelectTool
        from bokeh.plotting import figure
        from bokeh import palettes
        from bokeh.transform import factor_cmap

        if version.parse(bokeh.__version__) < version.parse("1.4"):
            raise ImportWarning

        output_file(opath)

        palette = palettes.viridis(len(factors)) if len(factors) <= 256 else len(factors) * ["#7f2704"]
      # palette = []
      # for factor in factors:
      #     print(factor)
      #     setting = factor.split(" ")[0].split("_")
      #     print(setting)
      #     setting.append("GT4HIST")
      #     setting = setting[1]
      #     print(setting)
      #     palette.extend({"BIN":palettes.viridis(80),
      #                     "BINSEG":palettes.viridis(60),
      #                     "Latin":palettes.viridis(40),
      #                     "WORDSEG":palettes.viridis(20),
      #                     "GT4HIST":palettes.viridis(0)}.get(setting,palettes.viridis(200)))

        #palette = palette[2]
        source = ColumnDataSource(data=dict(x=x, counts=error))

        p = figure(x_range=FactorRange(*x), plot_height=700,title=title.upper(),
                   toolbar_location="left", tools="")

        p.vbar(x='x', top='counts', width=0.9, source=source, line_color="white",
               fill_color=factor_cmap('x', palette=palette, factors=factors, start=1, end=2))
        #new_legend = p.legend[0]
        #p.legend[0].plot = None
        #p.add_layout(new_legend, 'left')
        p.y_range.start = 0
        p.x_range.range_padding = 0.1
        p.xaxis.major_label_orientation = 1
        p.xgrid.grid_line_color = None
        #p.xaxis.major_label_text_font_size = "5pt"
        p.xaxis.group_text_font_size = "7pt"
        #p.xaxis.ticker = [10,20,10,20]

        if args.show:
            show(p)
        else:
            save(p)

    except ImportError as error:
        # Output expected ImportErrors.
        log.error(error.__class__.__name__ + ": " + error.message)
    except ImportWarning as warning:
        # Output expected ImportErrors.
        log.warning(f"The installed version of bokeh is {bokeh.__version__}.\n"
                    f"This can lead to problems, we recommend version 1.4.0 or higher.")


def visualize():
    """Turn the summary reports into bar charts of CER and WER.

    For each directory in ``args.fdir`` all files matching
    ``*/**/*<norminfo>.sumreport.json`` are read in sorted order. With
    ``args.combine`` set, the infix is left out of the search pattern so that
    both raw and normalized summaries match, and the output files carry the
    infix ``.combined`` instead.

    The first dot-separated part of a summary's file name is used as book
    identifier and the third from last as data form. After each summary, the
    data collected so far for its book is plotted to
    ``<summary folder>/<cer|wer>_bnc_report<infix>.html``. Finally, all
    collected values are plotted to ``All_<cer|wer>_bnc_report<infix>.html``
    in the first directory of ``args.fdir``.

    Books missing from the label table are labelled with their identifier.
    If nothing was collected, a warning is logged and no overall plot is made.
    """
    # Display labels for known book identifiers.
    group = {"477429599": "Fraktur, 16. Jhd.",
             "493432051": "Fraktur, 17. Jhd.",
             "493447814": "Antiqua, 17. Jhd.",
             "417576986": "Antiqua, 18. Jhd.",
             "511097964": "Fraktur, 19. Jhd."}
    # Local infixes instead of rewriting args.norminfo while running.
    search_info = "" if args.combine else args.norminfo
    output_info = ".combined" if args.combine else args.norminfo

    pltdata = {"x": defaultdict(list),
               "cer": defaultdict(list),
               "wer": defaultdict(list),
               "factors": defaultdict(list)}
    for fdir in args.fdir:
        for report in sorted(fdir.rglob(f"*/**/*{search_info}.sumreport.json")):
            log.info(report)
            name_parts = report.name.split(".")
            mainfolder = name_parts[0]
            dataform = name_parts[-3]
            with report.open(encoding="utf-8") as fin:
                data = json.load(fin)
            for gtfile in sorted(data):
                for _, res in sorted(data[gtfile].items()):
                    fname = Path(res["gt"]).name.split(".")[0]
                    # OCR file name without the GT base name, e.g. the model.
                    label = (Path(res["ocr"]).name.replace(fname, "").split(".")[0].strip("_- ")
                             + f" ({dataform})")
                    pltdata["x"][mainfolder].append((group.get(mainfolder, mainfolder), label))
                    pltdata["cer"][mainfolder].append(res["cer"])
                    pltdata["wer"][mainfolder].append(res["wer"])
                    pltdata["factors"][mainfolder].append(label)
            for val in ["cer", "wer"]:
                plot(pltdata["x"][mainfolder],
                     pltdata[val][mainfolder],
                     pltdata["factors"][mainfolder],
                     report.parent.joinpath(f"{val}_bnc_report{output_info}.html"),
                     val)

    all_x = list(itertools.chain.from_iterable(pltdata["x"].values()))
    if not all_x:
        log.warning("No summary reports found, nothing to visualize.")
        return
    # Sorted so the factor-to-colour assignment is the same on every run.
    all_factors = sorted(set(itertools.chain.from_iterable(pltdata["factors"].values())))
    for val in ["cer", "wer"]:
        all_errors = itertools.chain.from_iterable(pltdata[val].values())
        x, error = zip(*sorted(zip(all_x, all_errors)))
        plot(x,
             error,
             all_factors,
             args.fdir[0].joinpath(f"All_{val}_bnc_report{output_info}.html"),
             f"ALL_{val}")


def evaluate():
    """Run the complete pipeline: compare, then summarize_json, then visualize."""
    pipeline = (compare, summarize_json, visualize)
    for step in pipeline:
        step()


if __name__ == "__main__":
    # Map the positional "function" argument to the function that implements it.
    dispatch = {
        "evaluate": evaluate,
        "compare": compare,
        "summarize": summarize_json,
        "visualize": visualize,
    }
    selected = dispatch[args.function]
    selected()
