'''
    LaTeX linter.

    Checks for common mistakes such as:
        * Unreferenced figures, tables, etc.
        * Nonexistent .bib entries for citations
        * Use of unwanted commands

    This linter is based on regular expressions run on each line.
    Items spanning multiple lines will not be found.

    Author: Seth Ebner
    Created: 9 December 2019
    Last Modified: 23 January 2020
'''

_version = '1.0'  # printed as "version=..." at the top of the lint report

import argparse
import os
import re
import sys

from collections import Counter, defaultdict, namedtuple

# Records produced by the get_* extractors: `document` is the path of the file the item was
# found in and `line` is its 1-based line number in that file.
Referable = namedtuple('Referable', ['name', 'kind', 'document', 'line'])  # a \label{kind:name}
Reference = namedtuple('Reference', ['name', 'kind', 'document', 'line'])  # a \ref-style use of a label
Citation = namedtuple('Citation', ['name', 'kind', 'document', 'line'])    # one key of a \cite-style command
Citable = namedtuple('Citable', ['name', 'kind', 'document', 'line'])      # one .bib entry (@kind{name, ...)
Call = namedtuple('Call', ['name', 'document', 'line'])                     # a use of an unwanted command
# Grouping key built by hash(); `kind` is None for elements whose kind is not part of their identity.
Identifier = namedtuple('Identifier', ['name', 'kind'])

def memoize(func):
    """
    Decorator caching `func`'s results, keyed on its positional arguments.

    All arguments must be hashable; keyword arguments are not supported. The cache is
    unbounded and every call with the same arguments returns the very same result object,
    so callers must treat cached results as read-only.

    functools.wraps keeps the wrapped function's name and docstring visible on the
    decorated function (the previous version replaced them with `memoized_func`'s).
    """
    import functools

    cache = dict()

    @functools.wraps(func)
    def memoized_func(*args):
        if args in cache:
            return cache[args]
        result = func(*args)
        cache[args] = result
        return result

    return memoized_func

def parse_args(argv=None):
    """
    Parse the linter's command-line options.

    argv: list of argument strings to parse. None (the default) parses sys.argv[1:], as
          before; passing a list allows parsing without touching sys.argv (e.g. in tests).

    Returns the argparse.Namespace; each `--foo-bar` option is available as `args.foo_bar`.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--dir', type=str, required=True, help='input directory')
    p.add_argument('--main', type=str, default='main.tex', help='main .tex file (relative to dir)')
    p.add_argument('--commands', type=str, nargs='+', default=[], help='LaTeX commands that should not be active')
    p.add_argument('--not-referenced', action='store_true', default=False, help='if specified, prints warnings that labeled document elements are not referenced')
    p.add_argument('--no-reference', action='store_true', default=False, help='if specified, prints warnings that references to document elements could not be resolved')
    p.add_argument('--not-cited', action='store_true', default=False, help='if specified, prints warnings that .bib entries are not cited')
    p.add_argument('--no-bib-entry', action='store_true', default=False, help='if specified, prints warnings that citations have no .bib entry')
    p.add_argument('--mult-bib-entry', action='store_true', default=False, help='if specified, prints warnings that a .bib entry is defined more than once')
    p.add_argument('--mult-def', action='store_true', default=False, help='if specified, prints warnings that a document element is defined more than once')
    p.add_argument('--print-citation-counts', action='store_true', default=False, help='if specified, prints number of citations of each document')
    p.add_argument('--print-document-element-summary', action='store_true', default=False, help='if specified, prints summary of document element mentions')

    return p.parse_args(argv)

def remove_comments(line):
    r"""
    Return `line` cut off at its first unescaped `%`, which starts a LaTeX comment.

    A `%` is escaped only when preceded by an odd number of consecutive backslashes:
    `50\% off` contains no comment, whereas in `a \\% note` the `\\` is a line break and the
    `%` does start a comment. If there is no comment the line is returned unchanged
    (including any trailing newline). `\verb` and verbatim environments are not understood,
    so a `%` inside them is also treated as a comment start.
    """
    backslashes = 0  # length of the run of backslashes immediately before the current char
    for pos, char in enumerate(line):
        if char == '%' and backslashes % 2 == 0:
            return line[:pos]
        backslashes = backslashes + 1 if char == '\\' else 0
    return line

@memoize
def clean_file(file_):
    """
    Read `file_` and return its lines with LaTeX comments and trailing whitespace removed.

    Element i of the result is line i+1 of the file, which lets callers report 1-based
    line numbers. Results are memoized per path, so treat the returned list as read-only.

    The file is decoded as UTF-8 regardless of the platform's locale encoding; undecodable
    bytes become U+FFFD instead of aborting the whole lint run.
    """
    with open(file_, "r", encoding="utf-8", errors="replace") as f:
        return [remove_comments(line).rstrip() for line in f]

def hash(x, fmt=False):
    """
    Identifier for a document element, used for comparisons and grouping.

    NOTE: this module-level function shadows the built-in `hash` throughout this file.

    x:   a Call, Citable, Citation, Referable or Reference (recognized by type name).
    fmt: if True, return a printable string instead of an Identifier.

    Referable/Reference elements are identified by name AND kind (`kind:name` when
    formatted); the other types have no kind separate from their name (kind=None).
    Raises TypeError for any other type.
    """
    type_name = type(x).__name__
    has_kind = type_name in ('Referable', 'Reference')
    if not has_kind and type_name not in ('Call', 'Citable', 'Citation'):
        raise TypeError(f"unrecognized type: {type_name}")

    if fmt:
        return f"{x.kind}:{x.name}" if has_kind else x.name
    # e.g. \ref{fig:overview} -> name='overview', kind='fig'
    return Identifier(name=x.name, kind=x.kind if has_kind else None)

def location(x, fmt=False):
    """
    Where a document element occurs.

    x:   a Call, Citable, Citation, Referable or Reference (recognized by type name).
    fmt: if True, return "document:line"; otherwise the tuple (document, line).
    Raises TypeError for any other type.
    """
    type_name = type(x).__name__
    if type_name not in ('Call', 'Citable', 'Citation', 'Referable', 'Reference'):
        raise TypeError(f"unrecognized type: {type_name}")

    where = (x.document, x.line)
    return f"{where[0]}:{where[1]}" if fmt else where

def get_calls(file_, commands):
    r"""
    Find the lines of `file_` that use any of the unwanted `commands`.

    A command ending in a letter only matches when not immediately followed by another
    letter, so flagging `\todo` does not also flag `\todonotes`. Lines containing a
    `\newcommand`, `\renewcommand` or `\def` are skipped so that the definition of an
    unwanted command is not itself reported.

    Returns a defaultdict mapping hash(call) -> list of Call, with at most one Call per
    (command, line) pair.
    """
    lines = clean_file(file_)

    # Whole control words only: `\def` must not match `\definecolor` or `\default`.
    definition_re = re.compile(r'\\(?:newcommand|renewcommand|def)(?![A-Za-z])')
    command_res = [
        (command, re.compile(re.escape(command) + (r'(?![A-Za-z])' if command[-1:].isalpha() else '')))
        for command in commands
    ]

    calls = defaultdict(list)
    for lineno, line in enumerate(lines, start=1):
        if definition_re.search(line):
            continue
        for command, command_re in command_res:
            if command_re.search(line):
                call = Call(name=command, document=file_, line=lineno)
                calls[hash(call)].append(call)

    return calls

def get_referables(file_):
    r"""
    Collect labeled document elements, i.e. every `\label{kind:name}` in `file_`.

    The label key is split at its first colon: the text before it is the kind (e.g. fig,
    tab, sec) and the rest is the name. Labels without a colon are ignored.

    Returns a defaultdict mapping hash(referable) -> list of Referable, one per occurrence.
    """
    lines = clean_file(file_)

    # The kind may not contain braces, so a colon-less label cannot run on into a later
    # command on the same line (`\label{intro} \ref{fig:x}` must not yield the kind
    # "intro} \ref{fig").
    label_re = re.compile(r'\\label\{\s*([^:{}]*)\s*:\s*(.*?)\s*\}')

    referables = defaultdict(list)
    for lineno, line in enumerate(lines, start=1):
        for referable_m in label_re.finditer(line):
            referable_kind, referable_name = referable_m.groups()
            referable = Referable(name=referable_name, kind=referable_kind, document=file_, line=lineno)
            referables[hash(referable)].append(referable)

    return referables

def get_references(file_):
    r"""
    Collect references to labeled document elements in `file_`.

    Recognizes `\ref`, `\autoref`, `\eqref`, `\pageref`, `\nameref`, `\cref` and `\Cref`,
    optionally starred. The argument may be a comma-separated list of keys (as cleveref
    allows); each key is split at its first colon into kind and name, the same way labels
    are. Keys without a colon are ignored.

    Returns a defaultdict mapping hash(reference) -> list of Reference, one per occurrence.
    """
    lines = clean_file(file_)

    # `[^{}]*` keeps a match inside a single brace group.
    reference_re = re.compile(r'\\(autoref|ref|eqref|pageref|nameref|cref|Cref)\*?\{([^{}]*)\}')

    references = defaultdict(list)
    for lineno, line in enumerate(lines, start=1):
        for reference_m in reference_re.finditer(line):
            for key in reference_m.group(2).split(','):
                key = key.lstrip()
                if ':' not in key:
                    continue
                reference_kind, _, reference_name = key.partition(':')
                reference = Reference(name=reference_name.strip(), kind=reference_kind, document=file_, line=lineno)
                references[hash(reference)].append(reference)

    return references

def get_citations(file_):
    r"""
    Collect citation keys used in `file_`.

    Recognizes `\cite`, `\newcite`, `\shortcite`, the natbib commands (`\citep`, `\citet`,
    `\Citep`, `\Citet`, `\citealp`, `\citealt`, `\citeauthor`, `\citeyear`, `\citeyearpar`)
    and biblatex's `\parencite`, `\textcite`, `\autocite`, each optionally starred and with
    up to two optional `[...]` arguments (`\citep[see][p.~3]{key}`). The key argument may
    list several comma-separated keys; whitespace is removed and empty keys are skipped.

    Returns a defaultdict mapping hash(citation) -> list of Citation, one per key occurrence.
    """
    lines = clean_file(file_)

    citation_re = re.compile(
        r'\\(cite|newcite|shortcite|citep|citet|citeyearpar'
        r'|Citep|Citet|citealp|citealt|citeauthor|citeyear|parencite|textcite|autocite)'
        r'\*?(?:\s*\[[^\]]*\]){0,2}\s*\{\s*(.*?)\s*\}'
    )

    citations = defaultdict(list)
    for lineno, line in enumerate(lines, start=1):
        for citation_m in citation_re.finditer(line):
            citation_kind = citation_m.group(1)
            # remove all whitespace, then split on comma
            citation_names = re.sub(r'\s+', '', citation_m.group(2)).split(',')

            for citation_name in citation_names:
                if not citation_name:  # e.g. `\cite{a,}` or `\cite{}`
                    continue
                citation = Citation(name=citation_name, kind=citation_kind, document=file_, line=lineno)
                citations[hash(citation)].append(citation)

    return citations

def get_citables(file_):
    """
    Collect the bibliography entries defined in the .bib file `file_`.

    Entry headers `@type{key,` and `@type(key,` are recognized for any entry type, matched
    case-insensitively (`@misc`, `@Article`, `@InProceedings`, `@techreport`, ...).
    `@string`, `@comment` and `@preamble` are BibTeX directives without a citation key and
    are skipped. Only headers written on a single line are found.

    Returns a defaultdict mapping hash(citable) -> list of Citable (more than one entry
    for a key means it is defined more than once).
    """
    lines = clean_file(file_)

    entry_re = re.compile(r'@(\w+)\s*[{(]\s*(.*?)\s*,')
    non_entries = {'string', 'comment', 'preamble'}

    citables = defaultdict(list)
    for lineno, line in enumerate(lines, start=1):
        # TODO: handle cases where name of paper may be on next line from @article (etc.)
        for citable_m in entry_re.finditer(line):
            citable_kind, citable_name = citable_m.groups()
            if citable_kind.lower() in non_entries:
                continue

            citable = Citable(name=citable_name, kind=citable_kind, document=file_, line=lineno)
            citables[hash(citable)].append(citable)

    return citables

def get_included_tex_files(file_, dir_):
    r"""
    Return the set of .tex files that `file_` pulls in directly via `\input{...}` or `\include{...}`.

    Names are joined onto `dir_` and normalized (so `./a.tex` and `a.tex` coincide). As in
    LaTeX, a name without an extension gets `.tex` appended (`\input{sections/intro}` ->
    `sections/intro.tex`); names with any other extension (e.g. `.pgf`, `.tikz`) and empty
    names are skipped. Nested inclusions are not followed here.
    """
    lines = clean_file(file_)

    inclusions = set()
    for line in lines:
        for inclusion_m in re.finditer(r'\\(input|include)\{\s*(.*?)\s*\}', line):
            inclusion_name = inclusion_m.group(2)
            if not inclusion_name:
                continue
            if not os.path.splitext(inclusion_name)[1]:
                inclusion_name += '.tex'

            if inclusion_name.endswith('.tex'):
                inclusions.add(os.path.normpath(os.path.join(dir_, inclusion_name)))  # get full path

    return inclusions

def get_all_included_tex_files(file_, dir_):
    """
    Return every .tex file reachable from `file_` through nested inclusions.

    Each file is expanded at most once, so shared includes are not re-read and an include
    cycle terminates instead of recursing forever. `file_` itself appears in the result only
    if some included file includes it back.
    """
    included_tex_files = set()
    pending = [file_]
    while pending:
        current = pending.pop()
        for included_file in get_included_tex_files(current, dir_):
            if included_file not in included_tex_files:
                included_tex_files.add(included_file)
                pending.append(included_file)

    return included_tex_files

def get_bibfiles(file_):
    r"""
    Return the .bib file names named in `file_`, as written (relative paths).

    `\bibliography{...}` takes a comma-separated list of names, normally without extension:
    `\bibliography{refs, extra}` -> ['refs.bib', 'extra.bib']; `.bib` is appended only when
    missing. biblatex's `\addbibresource[...]{refs.bib}` names a single file including its
    extension. Names are returned line by line, in order.
    """
    lines = clean_file(file_)

    bibfiles = []
    for line in lines:
        for bibfile_m in re.finditer(r'\\bibliography\{\s*(.*?)\s*\}', line):
            for bibfile_name in bibfile_m.group(1).split(','):
                bibfile_name = bibfile_name.strip()
                if not bibfile_name:
                    continue
                if not bibfile_name.endswith('.bib'):
                    bibfile_name += '.bib'
                bibfiles.append(bibfile_name)

        for resource_m in re.finditer(r'\\addbibresource(?:\[[^\]]*\])?\{\s*(.*?)\s*\}', line):
            if resource_m.group(1):
                bibfiles.append(resource_m.group(1))

    return bibfiles

def lint_file(file_, commands):
    """
    Run every per-.tex-file extractor on `file_`.

    Returns a dict with the keys 'file_calls', 'file_referables', 'file_references' and
    'file_citations', holding the results of get_calls, get_referables, get_references and
    get_citations respectively (in that evaluation order).
    """
    return {
        'file_calls': get_calls(file_, commands),
        'file_referables': get_referables(file_),
        'file_references': get_references(file_),
        'file_citations': get_citations(file_),
    }

def merge(old, new):
    """
    Add the contents of `new` into `old` and return `old`.

    Both arguments must be defaultdict(list) mappings of key -> list of items. For every key
    of `new`, its items are appended to `old[key]` (creating the entry if needed). `old` is
    modified in place; `new` is only read.

    Raises:
        TypeError: if either argument is not a defaultdict whose default_factory is list.
            (Raised explicitly rather than via `assert`, which `python -O` strips.)
    """
    for arg_name, mapping in (('old', old), ('new', new)):
        if not isinstance(mapping, defaultdict) or mapping.default_factory is not list:
            raise TypeError(f"`{arg_name}` must be a defaultdict(list), got {type(mapping).__name__}")

    for key, items in new.items():
        old[key].extend(items)

    return old

def _print_multiple_definitions(tag, definitions):
    """Print a `[tag]` warning for every key in `definitions` that has more than one instance."""
    for definition_instances in definitions.values():
        if len(definition_instances) > 1:
            print(f"[{tag}] {hash(definition_instances[0], fmt=True)} is defined more than once:")
            for instance in definition_instances:
                print(f"\t{location(instance, fmt=True)}")

def _print_warnings(args, calls, referables, references, citations, citables):
    """
    Print the version line, the [CMD] report and every warning category enabled in `args`.

    Cross-lookups use `.get` rather than indexing: indexing a defaultdict inserts an empty
    list for a missing key, which made later reports (e.g. the document element summary)
    count unresolved references as labeled elements.
    """
    print(f"version={_version}")
    print("*"*12 + " Warnings " + "*"*12)
    for instances in calls.values():
        print(f"[CMD] {instances[0].name} found:")
        for instance in instances:
            print(f"\t{location(instance, fmt=True)}")

    if args.mult_bib_entry:
        _print_multiple_definitions("MULT-BIB-ENTRY", citables)

    if args.mult_def:
        _print_multiple_definitions("MULT-DEF", referables)

    # Are any document elements created but not referenced?
    if args.not_referenced:
        for referable_k, definition_instances in referables.items():
            if not references.get(referable_k):
                print(f"[NOT-REFERENCED] {hash(definition_instances[0], fmt=True)} is not referenced. Introduced at {location(definition_instances[0], fmt=True)}")

    # Are any references unresolvable to document elements? (e.g., document element is unlabeled or does not exist)
    if args.no_reference:
        for reference_k, instances in references.items():
            if not referables.get(reference_k):
                print(f"[NO-REFERENCE] {hash(instances[0], fmt=True)} points to unlabeled/nonexistent document element. Referenced at:")
                for instance in instances:
                    print(f"\t{location(instance, fmt=True)}")

    # Are any bibliography entries not cited in the document?
    if args.not_cited:
        for citable_k, definition_instances in citables.items():
            if not citations.get(citable_k):
                first = definition_instances[0]
                print(f"[NOT-CITED] {first.name} is not cited. Introduced at {location(first, fmt=True)}")

    # Are any citations missing in the bibliography?
    if args.no_bib_entry:
        for citation_k, instances in citations.items():
            if not citables.get(citation_k):
                first = instances[0]
                print(f"[NO-BIB-ENTRY] {first.name} is not listed in the bibliography. Cited at {location(first, fmt=True)}")

    print()

def _print_citation_counts(citations):
    """Print how many times each cited key is cited, sorted by key."""
    print("*"*10 + " Citation Counts " + "*"*10)
    citation_counts = {citation.name: len(instances) for citation, instances in citations.items() if instances}
    for name, count in sorted(citation_counts.items()):
        print(f"{name}:\t{count}")
    print()

def _print_document_element_summary(referables, references):
    """Print, per label kind, how many elements are labeled, referenced and unreferenced."""
    kind2numreferables = Counter(referable.kind for referable in referables)
    kind2numreferenced = Counter(k.kind for k in referables if references.get(k))
    kinds = sorted(set(kind2numreferables) | set(kind2numreferenced))

    _kind_tag, _numlabeled_tag, _numreferenced_tag, _missing_tag = "Kind", "Labl'd", "Ref'd", "**Unref**"
    # The kind column is as wide as its widest entry (never narrower than its header, and
    # well-defined when nothing is labeled); the other columns are as wide as their headers.
    kind_width = max([len(_kind_tag)] + [len(kind) for kind in kinds])
    widths = [kind_width, len(_numlabeled_tag), len(_numreferenced_tag), len(_missing_tag)]

    # TODO: adjust numeric column widths based on max length of values (long values may overrun column)
    print("*"*8 + " Document Element Summary " + "*"*8)
    header = f"{_kind_tag:{widths[0]}}\t{_numlabeled_tag:{widths[1]}}\t{_numreferenced_tag:{widths[2]}}\t{_missing_tag:{widths[3]}}"
    print(header)
    # Tabs advance to the next multiple-of-8 column, so measure the header as rendered.
    print("-" * len(header.expandtabs()))
    for kind in kinds:
        numlabeled = kind2numreferables[kind]
        numreferenced = kind2numreferenced[kind]
        missing = numlabeled - numreferenced if numlabeled - numreferenced > 0 else ''
        print(f"{kind:{widths[0]}}\t{numlabeled:{widths[1]}}\t{numreferenced:{widths[2]}}\t{missing:{widths[3]}}")
    print()

def main():
    r"""
    Lint the LaTeX project given on the command line and print the report to stdout.

    Starting from the main .tex file, every .tex file reachable through \input/\include is
    linted, the .bib files that get_bibfiles() finds in those .tex files are read, and the
    warnings and summaries selected by the command-line flags are printed.
    """
    args = parse_args()
    inputdir = args.dir
    mainfile = os.path.normpath(os.path.join(inputdir, args.main))
    commands = args.commands

    referables = defaultdict(list)
    references = defaultdict(list)
    citations = defaultdict(list)
    citables = defaultdict(list)
    calls = defaultdict(list)

    # Determine (recursively) which files are actually used: the main file plus files
    # included via \include{} or \input{}. These files are linted directly; paths are
    # normalized so a file reached via different spellings is linted only once, and the
    # order (main file first, the rest sorted) makes the report deterministic.
    included = {os.path.normpath(f) for f in get_all_included_tex_files(mainfile, inputdir)}
    tex_files = [mainfile] + sorted(included - {mainfile})

    # .bib files may be named in any used .tex file; a file named twice is read only once
    # so its entries are not misreported as duplicate definitions.
    bibfiles = []
    for tex_file in tex_files:
        for bib_name in get_bibfiles(tex_file):
            bibfile = os.path.normpath(os.path.join(inputdir, bib_name))
            if bibfile not in bibfiles:
                bibfiles.append(bibfile)

    # Because .bib files can overwrite each other, read in .bib files before the .tex files.
    for bibfile in bibfiles:
        merge(citables, get_citables(bibfile))

    for tex_file in tex_files:
        lint = lint_file(tex_file, commands)

        merge(calls, lint["file_calls"])
        merge(referables, lint["file_referables"])
        merge(references, lint["file_references"])
        merge(citations, lint["file_citations"])

    # Print out logs
    _print_warnings(args, calls, referables, references, citations, citables)

    if args.print_citation_counts:
        _print_citation_counts(citations)

    if args.print_document_element_summary:
        _print_document_element_summary(referables, references)

if __name__ == "__main__":
    # Run the linter only when executed as a script, not when imported as a module.
    main()
