#!/usr/bin/env python3
"""
Pre-registered secondary analysis: does marcescent (persistent standing dead)
plant litter decompose more slowly than directly shed litter, and does the
difference depend on functional group?

Automated executor (arm B). Self-contained: downloads the pinned file from
Figshare, verifies md5, runs the pre-registered tests, writes every figure,
table, and results.json. Re-running reproduces every number.

Dataset: Figshare 10.6084/m9.figshare.25062800 (CC-BY-4.0),
         "Angst et al. 2024 Functional Ecology raw data",
         file "DataMarcescencedecomposition fin.xlsx" (md5 0c66ac4ef3d82ef09ed2c0332a4796b6).

Pre-registered tests (run exactly these):
  T1  decomposition (LossMass %) marcescent vs shed, paired by species  : Wilcoxon signed-rank
  T2  per-species (marcescent - shed) difference by functional group    : Mann-Whitney U (Forb vs Grass)
  T3a within each functional group, marcescent vs shed                   : Wilcoxon signed-rank
  T3b microbial biomass (Cmic) vs decomposition                         : Spearman correlation
  Multiplicity across {T2, T3a-forb, T3a-grass, T3b}: Benjamini-Hochberg (FDR) and Holm.
"""
import hashlib
import io
import json
import os
import urllib.request

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats

# Seed NumPy's global RNG so the fig-2 x-jitter (the only random draw in this
# file) is identical on every run.
np.random.seed(0)

# Pinned Figshare download and the md5 it must hash to; fetch() rejects anything else.
EXPECTED_MD5 = "0c66ac4ef3d82ef09ed2c0332a4796b6"
URL = "https://ndownloader.figshare.com/files/44230574"

# All outputs live beside this script; figures/ and tables/ are created at import time.
OUT = os.path.dirname(os.path.abspath(__file__))
FIG, TAB = (os.path.join(OUT, _sub) for _sub in ("figures", "tables"))
for _outdir in (FIG, TAB):
    os.makedirs(_outdir, exist_ok=True)
del _outdir

# Colour-blind-safe (Okabe-Ito) colours keyed by litter type (M = marcescent,
# S = shed) and by functional group.
CB = dict(M="#0072B2", S="#E69F00", Forb="#009E73", Grass="#CC79A7", Legume="#999999")


def fetch(url=None, expected_md5=None, timeout=120):
    """Download the pinned data file and verify its md5 checksum.

    Parameters
    ----------
    url : str, optional
        Source URL; ``None`` (the default) means the module-level ``URL``.
    expected_md5 : str, optional
        Expected hex md5 digest; ``None`` (the default) means ``EXPECTED_MD5``.
    timeout : float
        Timeout in seconds passed to ``urllib.request.urlopen``.

    Returns
    -------
    tuple of (bytes, str)
        The raw downloaded bytes and their hex md5 digest.

    Raises
    ------
    SystemExit
        If the downloaded bytes do not hash to ``expected_md5``.
        Errors raised by ``urlopen`` itself propagate unchanged.
    """
    # Resolve defaults at call time so the module constants stay the single source of truth.
    if url is None:
        url = URL
    if expected_md5 is None:
        expected_md5 = EXPECTED_MD5
    # Use the response as a context manager so the connection is closed
    # deterministically instead of being left for the garbage collector.
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        raw = resp.read()
    md5 = hashlib.md5(raw).hexdigest()
    if md5 != expected_md5:
        raise SystemExit(f"md5 mismatch: expected {expected_md5}, got {md5}")
    return raw, md5


def bh_holm(pvals):
    """Adjust p-values for multiplicity with Benjamini-Hochberg and Holm.

    Returns a pair ``(bh, holm)`` of Python lists, each in the same order as
    *pvals*, with every adjusted value capped at 1.0.
    """
    p = np.asarray(pvals, float)
    m = len(p)
    ascending = np.argsort(p)

    # Benjamini-Hochberg (step-up): visit 1-based ranks k = m..1 and carry the
    # running minimum of p * m / k downwards so adjusted values stay monotone.
    bh = np.empty(m)
    running_min = 1.0
    for k in range(m, 0, -1):
        j = ascending[k - 1]
        running_min = min(running_min, p[j] * m / k)
        bh[j] = min(running_min, 1.0)

    # Holm (step-down): visit ranks k = 1..m, multiply by the number of
    # hypotheses still in play (m - k + 1) and carry the running maximum up.
    holm = np.empty(m)
    running_max = 0.0
    for k, j in enumerate(ascending, start=1):
        running_max = max(running_max, (m - k + 1) * p[j])
        holm[j] = min(running_max, 1.0)

    return bh.tolist(), holm.tolist()


def rank_biserial_mwu(u, n1, n2):
    """Rank-biserial correlation from a Mann-Whitney U statistic.

    Computes ``1 - 2U / (n1 * n2)``. When *u* is the U of the first sample
    (as ``scipy.stats.mannwhitneyu`` reports it), the result is positive when
    values of the second sample tend to exceed those of the first, and negative
    in the opposite case.
    """
    max_u = n1 * n2
    share = 2.0 * u / max_u
    return 1.0 - share


def _fmt_p(p):
    """Format *p* for a figure title: 'p=0.xxx', or 'p<0.001' where three decimals would print 0.000."""
    return "p<0.001" if p < 0.001 else f"p={p:.3f}"


def _paired_rank_biserial(diff):
    """Matched-pairs rank-biserial correlation for a set of paired differences.

    r = (T+ - T-) / (T+ + T-), where T+ and T- are the sums of the ranks of
    |diff| over the positive and negative differences. Zero differences are
    dropped first, mirroring scipy.stats.wilcoxon's default zero_method="wilcox".
    Returns a float in [-1, 1] (positive when differences are mostly positive),
    or NaN when the input is empty or every difference is zero.
    """
    d = np.asarray(diff, dtype=float)
    d = d[d != 0]
    if d.size == 0:
        return float("nan")
    ranks = stats.rankdata(np.abs(d))
    t_plus = ranks[d > 0].sum()
    t_minus = ranks[d < 0].sum()
    return float((t_plus - t_minus) / (t_plus + t_minus))


def main():
    """Run the pre-registered analysis end to end and write every output.

    Downloads and md5-checks the pinned workbook (via fetch), builds
    species-level marcescent (M) / shed (S) mass-loss means, runs T1, T2,
    T3a and T3b, adjusts {T2, T3a-Forb, T3a-Grass, T3b} with BH and Holm,
    then writes results.json (next to this script), three PNG figures
    (under FIG), three CSV tables plus a Frictionless datapackage.json
    (under TAB), and prints a summary to stdout.
    """
    raw, md5 = fetch()
    d = pd.read_excel(io.BytesIO(raw), sheet_name="Decomposition")
    d = d[["Species", "Family", "Fgroup", "Replication", "Littertype", "CN", "LossMass [%]"]].copy()
    d = d.rename(columns={"LossMass [%]": "lossmass"})
    # Keep only marcescent (M) and shed (S) litter rows with a measured mass loss.
    ms = d[d["Littertype"].isin(["M", "S"])].dropna(subset=["lossmass"]).copy()

    # species-level means per littertype (pairs by species, guards pseudoreplication)
    sp = ms.groupby(["Species", "Fgroup", "Littertype"])["lossmass"].mean().reset_index()
    # One row per species that has BOTH an M and an S mean; others are dropped.
    wide = (
        sp.pivot_table(index=["Species", "Fgroup"], columns="Littertype", values="lossmass")
        .dropna(subset=["M", "S"])
        .reset_index()
    )
    # Positive diff = marcescent litter lost MORE mass than shed litter of the same species.
    wide["diff"] = wide["M"] - wide["S"]

    results = {"dataset": {"figshare_id": "25062800", "md5": md5, "n_species_paired": int(len(wide))}}

    # ---- T1: marcescent vs shed, paired by species -------------------------
    w = stats.wilcoxon(wide["M"], wide["S"], method="approx")
    z = w.zstatistic
    # z-based effect size r = |z| / sqrt(N), N = number of species pairs. This
    # is NOT a rank-biserial correlation, so it is reported under its own key;
    # the matched-pairs rank-biserial is computed from the ranked differences.
    r_t1 = abs(z) / np.sqrt(len(wide))
    rb_t1 = _paired_rank_biserial(wide["diff"])
    n_M_gt_S = int((wide["diff"] > 0).sum())
    n_S_gt_M = int((wide["diff"] < 0).sum())
    results["T1"] = {
        "test": "Wilcoxon signed-rank (species-level, marcescent vs shed)",
        "W": float(w.statistic), "z": float(z), "p": float(w.pvalue),
        "rank_biserial_r": float(rb_t1),
        "rank_biserial_sign": "positive = marcescent lost more mass than shed",
        "r_z_over_sqrt_n": float(r_t1),
        "median_M": float(wide["M"].median()), "median_S": float(wide["S"].median()),
        "mean_M": float(wide["M"].mean()), "mean_S": float(wide["S"].mean()),
        "median_paired_diff": float(wide["diff"].median()),
        # Lower mass loss = slower decomposition, so S > M means marcescent is slower.
        "n_species": int(len(wide)), "n_marcescent_slower": n_S_gt_M, "n_marcescent_faster": n_M_gt_S,
    }

    # ---- T2: per-species diff by functional group (Forb vs Grass) ----------
    forb = wide.loc[wide["Fgroup"] == "Forb", "diff"]
    grass = wide.loc[wide["Fgroup"] == "Grass", "diff"]
    u2 = stats.mannwhitneyu(forb, grass, alternative="two-sided")
    results["T2"] = {
        "test": "Mann-Whitney U on per-species (marcescent - shed) difference, Forb vs Grass",
        "U": float(u2.statistic), "p": float(u2.pvalue),
        # u2.statistic is U of the first sample (forb), so with 1 - 2U/(n1 n2)
        # a positive value means the Grass differences tend to be larger.
        "rank_biserial_r": float(rank_biserial_mwu(u2.statistic, len(forb), len(grass))),
        "rank_biserial_sign": "positive = Grass differences tend to exceed Forb differences",
        "mean_diff_forb": float(forb.mean()), "mean_diff_grass": float(grass.mean()),
        "n_forb": int(len(forb)), "n_grass": int(len(grass)),
        "note": f"Legume excluded from the contrast (n={int((wide['Fgroup']=='Legume').sum())} species, too few).",
    }

    # ---- T3a: within-group marcescent vs shed ------------------------------
    t3a = {}
    for g in ["Forb", "Grass"]:
        sub = wide[wide["Fgroup"] == g]
        wg = stats.wilcoxon(sub["M"], sub["S"], method="approx")
        t3a[g] = {"W": float(wg.statistic), "p": float(wg.pvalue), "n": int(len(sub)),
                  "median_M": float(sub["M"].median()), "median_S": float(sub["S"].median())}
    results["T3a"] = t3a

    # ---- T3b: microbial biomass (Cmic) vs decomposition --------------------
    plfa = pd.read_excel(io.BytesIO(raw), sheet_name="PLFA")
    plfa = plfa.rename(columns={"Cmic [µg C/g]": "Cmic", "Litter": "Littertype", "replicate": "Replication"})
    plfa = plfa[["Species", "Littertype", "Replication", "Cmic"]].dropna(subset=["Cmic"])
    # Replicate labels are reduced to their first run of digits so they can be
    # joined against the numeric Replication column of the decomposition sheet.
    plfa["Replication"] = pd.to_numeric(plfa["Replication"].astype(str).str.extract(r"(\d+)")[0], errors="coerce")
    dj = ms.copy()
    dj["Replication"] = pd.to_numeric(dj["Replication"], errors="coerce")
    joined = dj.merge(plfa, on=["Species", "Littertype", "Replication"], how="inner").dropna(subset=["Cmic", "lossmass"])
    rho_t3b, p_t3b = stats.spearmanr(joined["Cmic"], joined["lossmass"])
    results["T3b"] = {"test": "Spearman: microbial biomass Cmic vs decomposition (LossMass %)",
                      "rho": float(rho_t3b), "p": float(p_t3b), "n": int(len(joined))}

    # ---- multiplicity across the family {T2, T3a-forb, T3a-grass, T3b} -----
    fam = [("T2", results["T2"]["p"]), ("T3a-Forb", t3a["Forb"]["p"]),
           ("T3a-Grass", t3a["Grass"]["p"]), ("T3b", results["T3b"]["p"])]
    bh, holm = bh_holm([p for _, p in fam])
    results["multiplicity"] = {name: {"raw_p": p, "bh_p": bh[i], "holm_p": holm[i]}
                               for i, (name, p) in enumerate(fam)}

    with open(os.path.join(OUT, "results.json"), "w") as fh:
        json.dump(results, fh, indent=2)

    # ---- figures -----------------------------------------------------------
    # fig-1: paired species-level marcescent vs shed
    fig, ax = plt.subplots(figsize=(5.2, 4.2), dpi=150)
    for _, row in wide.iterrows():
        ax.plot([0, 1], [row["S"], row["M"]], color="0.7", lw=0.7, zorder=1)
    ax.scatter(np.zeros(len(wide)), wide["S"], color=CB["S"], label="shed", zorder=2, s=18)
    ax.scatter(np.ones(len(wide)), wide["M"], color=CB["M"], label="marcescent", zorder=2, s=18)
    ax.set_xticks([0, 1]); ax.set_xticklabels(["shed", "marcescent"])
    ax.set_ylabel("Litter mass loss (%)")
    ax.set_title(f"Marcescent vs shed decomposition by species (n={len(wide)})\nWilcoxon {_fmt_p(results['T1']['p'])}")
    ax.legend(frameon=False)
    fig.tight_layout(); fig.savefig(os.path.join(FIG, "fig-1-marcescent-vs-shed.png")); plt.close(fig)

    # fig-2: per-species diff by functional group
    fig, ax = plt.subplots(figsize=(5.2, 4.2), dpi=150)
    groups = ["Forb", "Grass", "Legume"]
    data = [wide.loc[wide["Fgroup"] == g, "diff"].values for g in groups]
    ax.boxplot(data, tick_labels=[f"{g}\n(n={len(x)})" for g, x in zip(groups, data)], showfliers=False)
    for i, g in enumerate(groups):
        # Horizontal jitter draws from the globally seeded NumPy RNG, so it is reproducible.
        y = data[i]; ax.scatter(np.full(len(y), i + 1) + np.random.uniform(-0.08, 0.08, len(y)), y, color=CB[g], s=16, zorder=3)
    ax.axhline(0, color="0.6", lw=0.8, ls="--")
    ax.set_ylabel("Per-species (marcescent − shed) mass loss (pp)")
    ax.set_title(f"Functional-group effect on the difference\nForb vs Grass Mann-Whitney {_fmt_p(results['T2']['p'])}")
    fig.tight_layout(); fig.savefig(os.path.join(FIG, "fig-2-diff-by-functional-group.png")); plt.close(fig)

    # fig-3: microbial biomass vs decomposition
    fig, ax = plt.subplots(figsize=(5.2, 4.2), dpi=150)
    for lt, name in [("M", "marcescent"), ("S", "shed")]:
        s = joined[joined["Littertype"] == lt]
        ax.scatter(s["Cmic"], s["lossmass"], color=CB[lt], s=18, label=name)
    ax.set_xlabel("Microbial biomass Cmic (µg C/g)"); ax.set_ylabel("Litter mass loss (%)")
    ax.set_title(f"Microbial biomass vs decomposition\nSpearman rho={rho_t3b:.2f}, {_fmt_p(p_t3b)}, n={len(joined)}")
    ax.legend(frameon=False)
    fig.tight_layout(); fig.savefig(os.path.join(FIG, "fig-3-microbial-vs-decomposition.png")); plt.close(fig)

    # ---- tables ------------------------------------------------------------
    t1 = wide[["Species", "Fgroup", "M", "S", "diff"]].rename(columns={"M": "lossmass_marcescent", "S": "lossmass_shed", "diff": "marcescent_minus_shed"})
    t1.to_csv(os.path.join(TAB, "tbl-1-species-level.csv"), index=False)

    # Legume has no within-group test, so its p-value cell is left empty.
    t2 = pd.DataFrame([
        {"fgroup": g, "n_species": len(wide[wide["Fgroup"] == g]),
         "mean_diff_pp": wide.loc[wide["Fgroup"] == g, "diff"].mean(),
         "within_group_wilcoxon_p": t3a.get(g, {}).get("p")}
        for g in groups
    ])
    t2.to_csv(os.path.join(TAB, "tbl-2-by-functional-group.csv"), index=False)

    t3 = pd.DataFrame([
        {"test": name, "raw_p": results["multiplicity"][name]["raw_p"],
         "bh_p": results["multiplicity"][name]["bh_p"], "holm_p": results["multiplicity"][name]["holm_p"]}
        for name, _ in fam
    ])
    t3.to_csv(os.path.join(TAB, "tbl-3-multiplicity.csv"), index=False)

    # Frictionless datapackage: schema is derived by re-reading each CSV header;
    # identifier columns are typed "string", everything else "number".
    dp = {"name": "marcescence-decomposition-tables", "resources": []}
    for path, name in [("tbl-1-species-level.csv", "species-level"), ("tbl-2-by-functional-group.csv", "by-functional-group"), ("tbl-3-multiplicity.csv", "multiplicity")]:
        cols = pd.read_csv(os.path.join(TAB, path)).columns
        dp["resources"].append({"name": name, "path": path, "schema": {"fields": [{"name": c, "type": "number" if c not in ("Species", "Fgroup", "fgroup", "test") else "string"} for c in cols]}})
    with open(os.path.join(TAB, "datapackage.json"), "w") as fh:
        json.dump(dp, fh, indent=2)

    print("md5", md5)
    print("T1 Wilcoxon p=%.4f z=%.3f r(z)=%.3f rank-biserial=%.3f | mean M=%.1f S=%.1f | marcescent slower in %d/%d species"
          % (results["T1"]["p"], results["T1"]["z"], r_t1, rb_t1, results["T1"]["mean_M"], results["T1"]["mean_S"], n_S_gt_M, len(wide)))
    print("T2 Forb vs Grass U=%.1f p=%.4f | mean diff forb=%.2f grass=%.2f" % (u2.statistic, u2.pvalue, forb.mean(), grass.mean()))
    print("T3a", {g: round(t3a[g]["p"], 4) for g in t3a})
    print("T3b Spearman rho=%.3f p=%.4f n=%d" % (rho_t3b, p_t3b, len(joined)))
    print("BH", [round(x, 4) for x in bh], "Holm", [round(x, 4) for x in holm])


if __name__ == "__main__":
    # Run the full pipeline only when executed as a script. main() returns
    # None, so SystemExit(None) exits with status 0 on success.
    raise SystemExit(main())
