"""
datasetpaper M1 analysis: asymmetric metabolic complementarity between the
endosymbiont Candidatus Bodocryptus vickermanii and its Bodo host.

Protocol: secondary-analysis 0.1 (protocols/secondary-analysis-0.1.json).
Source dataset: 10.6084/m9.figshare.31362613 (CC-BY-4.0), Supplementary Files
12 (endosymbiont + single-cell KEGG module completeness) and 13 (host genomes).

Self-contained and deterministic: downloads the two pinned files by Figshare id,
verifies md5, joins on KEGG module_accession, runs the pre-registered tests, and
writes figures, a derived table, and results.json. Re-running reproduces the
numbers (no randomness affects any statistic; a fixed seed only jitters fig-1).
"""
import hashlib, json, os, urllib.request
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.stats import wilcoxon, binomtest

# Seeded generator; in this file it is only drawn from to jitter the fig-1 scatter.
RNG = np.random.default_rng(12345)

# Output layout: everything lives next to this script.
HERE = os.path.dirname(os.path.abspath(__file__))
RAW, FIG, TAB = (os.path.join(HERE, sub) for sub in ("raw", "figures", "tables"))
for d in (RAW, FIG, TAB):
    os.makedirs(d, exist_ok=True)

# Local file name -> (Figshare file id, expected md5 hex digest).
PINNED = {
    "file12.xlsx": ("61984579", "2f60622b02446792777fcde279cf376b"),
    "file13.xlsx": ("61984582", "ff04aa7c1f8b64e1f356b2e84ab00ca3"),
}

# Single-cell identifiers; they are spliced into the "<cell> completeness" and
# "Bodo <cell> completeness" column names when the workbooks are read.
CELLS = [
    "F10",
    "B7",
    "A8",
    "A10",
    "B2",
    "G10",
    "H10",
]

# Completeness cut-offs (percent) at which modules are categorised.
THRESHOLDS = [50, 67, 80]
# Primary completeness threshold (percent).
THETA = 67


def fetch(name):
    """Return the local path of a pinned file, downloading it first if absent.

    ``name`` is looked up in ``PINNED`` to obtain the Figshare file id and the
    expected md5 hex digest. The file is cached as ``RAW/<name>``; an existing
    copy is reused without touching the network. The md5 of the local copy is
    checked on every call, cached or not.

    Returns:
        The path of the verified local file.

    Raises:
        KeyError: ``name`` is not a key of ``PINNED``.
        AssertionError: the md5 of the local file differs from the pinned one.
    """
    fid, md5 = PINNED[name]
    path = os.path.join(RAW, name)
    if not os.path.exists(path):
        # Download next to the target and move it into place afterwards, so an
        # interrupted transfer never leaves a truncated file at ``path`` that
        # later runs would mistake for the cached copy.
        partial = path + ".part"
        urllib.request.urlretrieve(f"https://ndownloader.figshare.com/files/{fid}", partial)
        os.replace(partial, path)

    # Hash in chunks inside a context manager: the handle is closed
    # deterministically and the whole file is never held in memory at once.
    digest = hashlib.md5()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    got = digest.hexdigest()

    if got != md5:
        # Raised explicitly instead of via ``assert`` so the integrity check is
        # not stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError(f"md5 mismatch for {name}: {got} != {md5}")
    return path


def classify(host, sym, theta):
    """Count paired modules by which partner reaches the threshold.

    ``host`` and ``sym`` are paired array-likes that support element-wise
    ``>=``, ``&`` and ``~`` (the callers in this file pass NumPy float arrays).
    A value counts as complete when it is ``>= theta``.

    Returns:
        A dict with integer counts under the keys ``"both"``, ``"host_only"``,
        ``"symbiont_only"`` and ``"neither"``, in that order.
    """
    host_ok = host >= theta
    sym_ok = sym >= theta
    host_short = ~host_ok
    sym_short = ~sym_ok

    masks = (
        ("both", host_ok & sym_ok),
        ("host_only", host_ok & sym_short),
        ("symbiont_only", host_short & sym_ok),
        ("neither", host_short & sym_short),
    )
    return {label: int(mask.sum()) for label, mask in masks}


def _load_shared_modules():
    """Read both pinned workbooks and inner-join them on ``module_accession``.

    Only the columns the analysis reads are kept: the module annotation and
    the consensus / per-cell completeness columns from file 12, and the
    ``bsal`` / per-cell ``Bodo`` completeness columns from file 13.
    """
    f12 = pd.read_excel(fetch("file12.xlsx"))
    f13 = pd.read_excel(fetch("file13.xlsx"))

    keep12 = ["module_accession", "pathway_class", "pathway_name", "C.bv completeness"] + \
             [f"{cell} completeness" for cell in CELLS]
    keep13 = ["module_accession", "bsal completeness"] + \
             [f"Bodo {cell} completeness" for cell in CELLS]
    return pd.merge(f12[keep12], f13[keep13], on="module_accession", how="inner")


def _category_labels(host, sym, theta):
    """Label every module with the category ``classify`` would count it under.

    Built from the same ``>= theta`` masks as ``classify`` so the per-module
    table and the category counts cannot disagree (including for NaN values,
    which fail the comparison and are treated as below the threshold).
    """
    hc, sc = host >= theta, sym >= theta
    labels = np.select(
        [hc & sc, hc & ~sc, ~hc & sc],
        ["both", "host_only", "symbiont_only"],
        default="neither",
    )
    return labels.tolist()


def _per_cell_categories(m, n_shared):
    """Classify each single-cell host/symbiont pair at ``THETA``.

    Returns a dict keyed by cell id holding the host-only and symbiont-only
    fractions (of ``n_shared`` modules, rounded to 4 places) followed by the
    raw category counts.
    """
    per_cell = {}
    for cell in CELLS:
        h = m[f"Bodo {cell} completeness"].to_numpy(float)
        s = m[f"{cell} completeness"].to_numpy(float)
        cc = classify(h, s, THETA)
        per_cell[cell] = {
            "host_only_frac": round(cc["host_only"] / n_shared, 4),
            "symbiont_only_frac": round(cc["symbiont_only"] / n_shared, 4),
            **cc,
        }
    return per_cell


def _write_module_table(m, host, sym, diff):
    """Write tbl-1, the per-module completeness and category table, as CSV."""
    tbl = m[["module_accession", "pathway_class", "pathway_name"]].copy()
    tbl["host_bsal_completeness"] = host
    tbl["symbiont_Cbv_completeness"] = sym
    tbl["host_minus_symbiont"] = diff
    tbl[f"category_at_{THETA}"] = _category_labels(host, sym, THETA)
    tbl_path = os.path.join(TAB, "tbl-1-module-classification.csv")
    tbl.to_csv(tbl_path, index=False)


def _plot_paired_completeness(host, sym, diff, n_shared, n_host_gt, n_sym_gt, w_p):
    """Draw fig-1: jittered host-vs-symbiont completeness scatter with summary box."""
    plt.figure(figsize=(5.2, 5))
    # Jitter is cosmetic only (it separates overplotted points); x is drawn
    # before y so the seeded sequence is consumed in a fixed order.
    jx = sym + RNG.normal(0, 1.1, size=len(sym))
    jy = host + RNG.normal(0, 1.1, size=len(host))
    plt.scatter(jx, jy, s=18, alpha=0.55, edgecolor="none", color="#534AB7")
    plt.plot([0, 100], [0, 100], "--", color="#888780", lw=1)  # y = x reference
    plt.xlabel("endosymbiont C. Bodocryptus vickermanii\nKEGG module completeness (%)")
    plt.ylabel("Bodo host (B. saltans)\nKEGG module completeness (%)")
    plt.title(f"Paired module completeness ({n_shared} shared KEGG modules)")
    plt.text(4, 92, f"host > symbiont: {n_host_gt}\nsymbiont > host: {n_sym_gt}\n"
                    f"median diff: {np.median(diff):.1f} pts\nWilcoxon p = {w_p:.2e}",
             fontsize=9, va="top",
             bbox=dict(boxstyle="round", fc="#EEEDFE", ec="#534AB7", alpha=0.9))
    plt.xlim(-3, 103)
    plt.ylim(-3, 103)
    plt.tight_layout()
    plt.savefig(os.path.join(FIG, "fig-1-paired-completeness.png"), dpi=150)
    plt.close()


def _plot_category_breakdown(cats):
    """Draw fig-2: grouped bars of the four category counts at each threshold."""
    plt.figure(figsize=(6, 4.2))
    order = ["both", "host_only", "symbiont_only", "neither"]
    colors = {"both": "#0F6E56", "host_only": "#534AB7",
              "symbiont_only": "#D85A30", "neither": "#B4B2A9"}
    x = np.arange(len(THRESHOLDS))
    wbar = 0.19
    for i, k in enumerate(order):
        # (i - 1.5) centres the four bars of a group on its tick.
        plt.bar(x + (i - 1.5) * wbar, [cats[t][k] for t in THRESHOLDS], wbar,
                label=k.replace("_", " "), color=colors[k])
    plt.xticks(x, [f"{t}%" for t in THRESHOLDS])
    plt.xlabel("completeness threshold")
    plt.ylabel("number of KEGG modules")
    plt.title("Module category by threshold (host = B. saltans, symbiont = C.bv)")
    plt.legend(fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(FIG, "fig-2-category-breakdown.png"), dpi=150)
    plt.close()


def _plot_per_cell_consistency(ho_fracs, so_fracs):
    """Draw fig-3: host-only vs symbiont-only percentage for every single cell."""
    plt.figure(figsize=(6.2, 4))
    xc = np.arange(len(CELLS))
    plt.bar(xc - 0.2, ho_fracs * 100, 0.4, label="host-only", color="#534AB7")
    plt.bar(xc + 0.2, so_fracs * 100, 0.4, label="symbiont-only", color="#D85A30")
    plt.axhline(ho_fracs.mean() * 100, ls="--", color="#534AB7", lw=1)  # mean host-only
    plt.xticks(xc, CELLS)
    plt.xlabel("single-cell host-symbiont pair")
    plt.ylabel(f"% of shared modules (threshold {THETA}%)")
    plt.title("Consistency of asymmetry across single cells")
    plt.legend(fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(FIG, "fig-3-per-cell-consistency.png"), dpi=150)
    plt.close()


def main():
    """Run the full analysis and write the table, the figures and results.json.

    Steps, in order: load and join the two workbooks; T1 paired Wilcoxon
    signed-rank test on host vs symbiont completeness; T2 category counts at
    every threshold plus an exact McNemar test at ``THETA``; T3 per-cell
    category fractions; then tbl-1, fig-1..3 and ``results.json`` (which is
    also printed to stdout).

    Raises:
        ValueError: the two workbooks share no ``module_accession``.
    """
    m = _load_shared_modules()
    n_shared = len(m)
    if n_shared == 0:
        # Every later step divides by or tests on the shared modules; fail
        # with a clear message instead of a ZeroDivisionError further down.
        raise ValueError("no KEGG modules are shared between file12.xlsx and file13.xlsx")

    host = m["bsal completeness"].to_numpy(float)     # Bodo saltans host
    sym = m["C.bv completeness"].to_numpy(float)       # endosymbiont consensus
    diff = host - sym

    # T1: paired Wilcoxon signed-rank, host vs symbiont completeness
    w_stat, w_p = wilcoxon(host, sym, zero_method="wilcox", alternative="two-sided")
    n_host_gt = int((host > sym).sum())
    n_sym_gt = int((sym > host).sum())
    n_tie = int((host == sym).sum())

    # T2: thresholded categories + McNemar (exact binomial on discordant pairs)
    cats = {t: classify(host, sym, t) for t in THRESHOLDS}
    primary = cats[THETA]
    b, c = primary["host_only"], primary["symbiont_only"]
    if (b + c) > 0:
        mcnemar_p = binomtest(min(b, c), b + c, 0.5, alternative="two-sided").pvalue
    else:
        mcnemar_p = 1.0  # no discordant pairs: nothing to test
    ratio = round(b / c, 2) if c > 0 else None

    # T3: consistency across the single-cell host-symbiont pairs at THETA
    per_cell = _per_cell_categories(m, n_shared)
    ho_fracs = np.array([per_cell[cell]["host_only_frac"] for cell in CELLS])
    so_fracs = np.array([per_cell[cell]["symbiont_only_frac"] for cell in CELLS])

    # Outputs are written in a fixed order: table, fig-1, fig-2, fig-3, JSON.
    _write_module_table(m, host, sym, diff)
    _plot_paired_completeness(host, sym, diff, n_shared, n_host_gt, n_sym_gt, w_p)
    _plot_category_breakdown(cats)
    _plot_per_cell_consistency(ho_fracs, so_fracs)

    results = {
        "dataset_doi": "10.6084/m9.figshare.31362613",
        "n_shared_modules": n_shared,
        "mean_completeness": {"host_bsal": round(float(host.mean()), 2),
                              "symbiont_Cbv": round(float(sym.mean()), 2)},
        "T1_wilcoxon": {"statistic": round(float(w_stat), 1), "p_value": float(w_p),
                        "median_host_minus_symbiont": round(float(np.median(diff)), 2),
                        "host_gt_symbiont": n_host_gt, "symbiont_gt_host": n_sym_gt,
                        "ties": n_tie},
        "T2_categories_by_threshold": cats,
        # Key follows THETA so it cannot drift from the threshold actually used.
        f"T2_mcnemar_at_{THETA}": {"host_only": b, "symbiont_only": c,
                                   "host_only_to_symbiont_only_ratio": ratio,
                                   "p_value": float(mcnemar_p)},
        "T3_consistency": {"per_cell": per_cell,
                           "host_only_frac_mean": round(float(ho_fracs.mean()), 4),
                           "host_only_frac_sd": round(float(ho_fracs.std(ddof=1)), 4),
                           "symbiont_only_frac_mean": round(float(so_fracs.mean()), 4),
                           "symbiont_only_frac_sd": round(float(so_fracs.std(ddof=1)), 4)},
        "figures": ["figures/fig-1-paired-completeness.png",
                    "figures/fig-2-category-breakdown.png",
                    "figures/fig-3-per-cell-consistency.png"],
        "table": "tables/tbl-1-module-classification.csv",
    }
    with open(os.path.join(HERE, "results.json"), "w", encoding="utf-8") as fh:
        json.dump(results, fh, indent=2)
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()
