#!/usr/bin/env python3
"""
Pre-registered secondary analysis of a volant-seabird foraging dataset.

Question (fixed):
  During foraging bouts, does the volant seabird's prey-catch count differ
  between the two interaction modes (flight vs surface), and does catch
  increase with the number of conspecific penguins present?

Pre-registered tests (run exactly these, nothing else):
  T1  catch ~ interaction mode        : Mann-Whitney U (+ rank-biserial r)
  T2  catch ~ conspmax                : Spearman rank correlation
  T3a catch ~ conspinit               : Spearman (robustness of T2)
  T3b catch ~ conspmax within mode    : Spearman, per interaction mode
  Multiplicity correction across the correlation family {T2, T3a,
  T3b-flight, T3b-surface}: Benjamini-Hochberg (FDR) and Holm.
  Sensitivity to pseudoreplication: per-ID aggregate re-run of T1 and T2.

The script downloads the pinned file from Zenodo, verifies its md5, runs
the tests, and writes every figure, table, results.json, and the tables
datapackage. Re-running reproduces every number.
"""

import hashlib
import io
import json
import os
import subprocess
import sys
import urllib.request

import numpy as np
import pandas as pd
from scipy import stats

# ----------------------------------------------------------------------
# Determinism
# ----------------------------------------------------------------------
# Fixed seed so that re-running the script reproduces every output.
SEED = 20240521
# Seeds NumPy's legacy global generator (np.random.*). Nothing in this file
# draws from it directly; presumably kept as a safety net for library code.
np.random.seed(SEED)
# Dedicated Generator; in this file it is consumed only by the point jitter
# in make_figures().
RNG = np.random.default_rng(SEED)

# Outputs are written next to this script. The directories are created at
# import time (module-level side effect), before main() runs.
HERE = os.path.dirname(os.path.abspath(__file__))
FIG_DIR = os.path.join(HERE, "figures")
TBL_DIR = os.path.join(HERE, "tables")
os.makedirs(FIG_DIR, exist_ok=True)
os.makedirs(TBL_DIR, exist_ok=True)

# ----------------------------------------------------------------------
# 1. Fetch + verify
# ----------------------------------------------------------------------
# Name of the pinned input file; also used as the local cache file name
# (next to this script) by fetch_data().
DATA_FNAME = "upforgrabs_lmm_data.csv"
# md5 the downloaded/cached bytes must match before they are parsed.
EXPECTED_MD5 = "37f340592899c62ea73fc4d8eb010c85"
# Download locations for the same Zenodo file, tried in order by fetch_data().
URLS = [
    "https://zenodo.org/records/4964380/files/upforgrabs_lmm_data.csv?download=1",
    "https://zenodo.org/api/records/4964380/files/upforgrabs_lmm_data.csv/content",
]


def fetch_data():
    """Load the pinned dataset, downloading it if no verified local copy exists.

    A cached copy next to this script is used only if its md5 equals
    ``EXPECTED_MD5``. Otherwise each URL in ``URLS`` is tried in order. A
    mirror that raises, or that returns bytes with the wrong md5, is skipped
    and the next one is tried. Only verified bytes are written to the cache,
    so a bad download can never poison it.

    Returns
    -------
    (pandas.DataFrame, str)
        The CSV parsed with its first column as the index, and the md5 hex
        digest of the bytes it was parsed from.

    Raises
    ------
    ValueError
        If at least one mirror returned data but none matched EXPECTED_MD5.
    RuntimeError
        If every download attempt raised.
    """
    local = os.path.join(HERE, DATA_FNAME)
    raw = None
    if os.path.exists(local):
        with open(local, "rb") as fh:
            cached = fh.read()
        if hashlib.md5(cached).hexdigest() == EXPECTED_MD5:
            raw = cached
    if raw is None:
        last_err = None
        mismatched_md5 = None
        for url in URLS:
            try:
                with urllib.request.urlopen(url, timeout=60) as resp:
                    payload = resp.read()
            except Exception as e:  # noqa: BLE001 - any failure: try next mirror
                last_err = e
                continue
            digest = hashlib.md5(payload).hexdigest()
            if digest == EXPECTED_MD5:
                raw = payload
                break
            # Wrong bytes from this mirror (e.g. truncated or an error page):
            # remember the digest for the error message and keep looking.
            mismatched_md5 = digest
        if raw is None:
            if mismatched_md5 is not None:
                raise ValueError(
                    f"md5 mismatch: expected {EXPECTED_MD5}, observed {mismatched_md5}"
                )
            raise RuntimeError(f"Could not download {DATA_FNAME}: {last_err}")
        # Cache only after verification.
        with open(local, "wb") as fh:
            fh.write(raw)
    observed_md5 = hashlib.md5(raw).hexdigest()
    if observed_md5 != EXPECTED_MD5:  # defensive; the paths above guarantee a match
        raise ValueError(
            f"md5 mismatch: expected {EXPECTED_MD5}, observed {observed_md5}"
        )
    df = pd.read_csv(io.BytesIO(raw), index_col=0)
    return df, observed_md5


# ----------------------------------------------------------------------
# Effect-size helpers
# ----------------------------------------------------------------------
def rank_biserial_from_U(U1, n1, n2):
    """Convert a Mann-Whitney U statistic into a rank-biserial correlation.

    Parameters
    ----------
    U1 : float
        The U statistic belonging to group 1 (the first sample).
    n1, n2 : int
        Sizes of group 1 and group 2.

    Returns
    -------
    float
        ``2 * U1 / (n1 * n2) - 1``, in [-1, 1]. Positive values mean group 1
        tends to take larger values than group 2.
    """
    # U1 divided by the number of cross-group pairs lies in [0, 1]; the
    # affine map below stretches it onto [-1, 1].
    pair_fraction = U1 / (n1 * n2)
    return 2.0 * pair_fraction - 1.0


def _json_safe(obj):
    """Recursively replace non-finite floats (NaN, +/-inf) with ``None``.

    ``json.dump`` would otherwise write the bare tokens ``NaN`` / ``Infinity``,
    which are not valid JSON. Dicts are rebuilt with the same keys; lists and
    tuples become lists (as ``json`` would serialize them anyway). Every other
    value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: _json_safe(val) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_json_safe(val) for val in obj]
    if isinstance(obj, float) and not np.isfinite(obj):
        return None
    return obj


def main():
    """Run the pre-registered analysis end to end.

    Steps:
    1. Fetch and md5-verify the data, then profile it.
    2. Run T1/T2/T3a/T3b.
    3. Apply Holm and BH corrections to the correlation family.
    4. Run the per-ID sensitivity re-analysis.
    5. Write tbl-1..tbl-3 (TBL_DIR), the figures (make_figures), the tables
       datapackage (write_datapackage) and results.json (next to this
       script), and print a summary.

    Returns
    -------
    dict
        The nested results written to results.json. In this returned dict
        non-finite statistics stay NaN; in the file they are written as null
        so that results.json is strict JSON.
    """
    df, observed_md5 = fetch_data()

    # ------------------------------------------------------------------
    # 2. Profile
    # ------------------------------------------------------------------
    n_rows = len(df)
    n_id = df["ID"].nunique()
    n_bout = df["IDbout"].nunique()
    n_by_mode = df["intermode"].value_counts().to_dict()
    n_zeros = int((df["catch"] == 0).sum())
    max_bouts_per_id = int(df["ID"].value_counts().max())

    profile = {
        "n_rows": int(n_rows),
        "n_unique_ID": int(n_id),
        "n_unique_IDbout": int(n_bout),
        "n_missing_total": int(df.isna().sum().sum()),
        "n_by_intermode": {k: int(v) for k, v in n_by_mode.items()},
        "catch_n_zeros": n_zeros,
        "catch_zero_fraction": round(n_zeros / n_rows, 4),
        "catch_min": int(df["catch"].min()),
        "catch_max": int(df["catch"].max()),
        "catch_median": float(df["catch"].median()),
        "catch_mean": round(float(df["catch"].mean()), 4),
        "max_bouts_per_ID": max_bouts_per_id,
        "n_ID_with_repeat": int((df["ID"].value_counts() > 1).sum()),
        "elapsed_median_s": float(df["elapsed"].median()),
        "conspmax_median": float(df["conspmax"].median()),
        "conspinit_median": float(df["conspinit"].median()),
    }

    # Per-group summary table (tbl-1)
    g_summary = (
        df.groupby("intermode")["catch"]
        .agg(n="count", median="median", mean="mean", sd="std",
             min="min", max="max", n_zeros=lambda s: int((s == 0).sum()))
        .reset_index()
    )
    g_summary["mean"] = g_summary["mean"].round(4)
    g_summary["sd"] = g_summary["sd"].round(4)
    tbl1_path = os.path.join(TBL_DIR, "tbl-1-group-summary.csv")
    g_summary.to_csv(tbl1_path, index=False)

    flight = df.loc[df["intermode"] == "flight", "catch"].to_numpy()
    surface = df.loc[df["intermode"] == "surface", "catch"].to_numpy()
    n_flight, n_surface = len(flight), len(surface)

    # ------------------------------------------------------------------
    # T1 : Mann-Whitney U, catch flight vs surface
    # ------------------------------------------------------------------
    U_flight, p_t1 = stats.mannwhitneyu(flight, surface, alternative="two-sided")
    rbc = rank_biserial_from_U(U_flight, n_flight, n_surface)
    # U for surface (complementary)
    U_surface = n_flight * n_surface - U_flight
    t1 = {
        "test": "Mann-Whitney U (two-sided)",
        "comparison": "catch: flight vs surface",
        "U_flight": float(U_flight),
        "U_surface": float(U_surface),
        "p_raw": float(p_t1),
        "effect_size_rank_biserial": round(float(rbc), 4),
        "n_flight": int(n_flight),
        "n_surface": int(n_surface),
        "median_flight": float(np.median(flight)),
        "median_surface": float(np.median(surface)),
        "mean_flight": round(float(np.mean(flight)), 4),
        "mean_surface": round(float(np.mean(surface)), 4),
    }

    # ------------------------------------------------------------------
    # T2 : Spearman catch ~ conspmax
    # ------------------------------------------------------------------
    rho_t2, p_t2 = stats.spearmanr(df["catch"], df["conspmax"])
    t2 = {
        "test": "Spearman rank correlation",
        "relation": "catch ~ conspmax",
        "rho": round(float(rho_t2), 4),
        "p_raw": float(p_t2),
        "n": int(n_rows),
    }

    # ------------------------------------------------------------------
    # T3a : Spearman catch ~ conspinit
    # ------------------------------------------------------------------
    rho_t3a, p_t3a = stats.spearmanr(df["catch"], df["conspinit"])
    t3a = {
        "test": "Spearman rank correlation",
        "relation": "catch ~ conspinit",
        "rho": round(float(rho_t3a), 4),
        "p_raw": float(p_t3a),
        "n": int(n_rows),
    }

    # ------------------------------------------------------------------
    # T3b : Spearman catch ~ conspmax within each mode
    # ------------------------------------------------------------------
    t3b = {}
    for mode in ["flight", "surface"]:
        sub = df[df["intermode"] == mode]
        rho, p = stats.spearmanr(sub["catch"], sub["conspmax"])
        t3b[mode] = {
            "test": "Spearman rank correlation",
            "relation": f"catch ~ conspmax | intermode={mode}",
            "rho": round(float(rho), 4),
            "p_raw": float(p),
            "n": int(len(sub)),
        }

    # ------------------------------------------------------------------
    # Multiplicity correction across the correlation family
    # ------------------------------------------------------------------
    family = [
        ("T2_catch~conspmax", p_t2),
        ("T3a_catch~conspinit", p_t3a),
        ("T3b_flight", t3b["flight"]["p_raw"]),
        ("T3b_surface", t3b["surface"]["p_raw"]),
    ]
    names = [f[0] for f in family]
    praw = np.array([f[1] for f in family])
    holm = holm_correction(praw)
    bh = benjamini_hochberg(praw)
    correction = {
        "family": names,
        "p_raw": [float(x) for x in praw],
        "p_holm": [float(x) for x in holm],
        "p_bh_fdr": [float(x) for x in bh],
        "note": ("T1 (Mann-Whitney, different response structure) is reported "
                 "separately and not folded into the correlation family; its "
                 "raw p is the only catch-by-mode test."),
    }

    # ------------------------------------------------------------------
    # Sensitivity: per-ID aggregate (guard vs pseudoreplication)
    # ------------------------------------------------------------------
    # One row per bird: median catch, median conspmax, and the bird's
    # interaction mode. A bird gets a mode only if all of its rows share
    # one mode; otherwise it is labelled "mixed".
    per_id = df.groupby("ID").agg(
        catch_median=("catch", "median"),
        conspmax_median=("conspmax", "median"),
        n_bouts=("IDbout", "count"),
    ).reset_index()
    mode_by_id = (
        df.groupby("ID")["intermode"]
        .agg(lambda s: s.mode().iloc[0] if s.nunique() == 1 else "mixed")
    )
    per_id["mode"] = per_id["ID"].map(mode_by_id).values

    # T1 sensitivity: per-ID median catch by unambiguous mode
    fl_id = per_id.loc[per_id["mode"] == "flight", "catch_median"].to_numpy()
    su_id = per_id.loc[per_id["mode"] == "surface", "catch_median"].to_numpy()
    if len(fl_id) >= 1 and len(su_id) >= 1:
        U_id, p_id = stats.mannwhitneyu(fl_id, su_id, alternative="two-sided")
        rbc_id = rank_biserial_from_U(U_id, len(fl_id), len(su_id))
    else:
        U_id, p_id, rbc_id = np.nan, np.nan, np.nan
    # T2 sensitivity: per-ID median catch ~ median conspmax
    rho_id, p_rho_id = stats.spearmanr(per_id["catch_median"],
                                       per_id["conspmax_median"])
    sensitivity = {
        "n_ID": int(len(per_id)),
        "n_ID_flight": int(len(fl_id)),
        "n_ID_surface": int(len(su_id)),
        "n_ID_mixed": int((per_id["mode"] == "mixed").sum()),
        "T1_perID_MannWhitneyU": (None if np.isnan(U_id) else float(U_id)),
        "T1_perID_p": (None if np.isnan(p_id) else float(p_id)),
        "T1_perID_rank_biserial": (None if np.isnan(rbc_id)
                                   else round(float(rbc_id), 4)),
        "T1_perID_median_flight": (None if len(fl_id) == 0
                                   else float(np.median(fl_id))),
        "T1_perID_median_surface": (None if len(su_id) == 0
                                    else float(np.median(su_id))),
        "T2_perID_spearman_rho": round(float(rho_id), 4),
        "T2_perID_p": float(p_rho_id),
    }
    tbl2_path = os.path.join(TBL_DIR, "tbl-2-per-id-sensitivity.csv")
    per_id_out = per_id.copy()
    per_id_out["catch_median"] = per_id_out["catch_median"].round(4)
    per_id_out["conspmax_median"] = per_id_out["conspmax_median"].round(4)
    per_id_out.to_csv(tbl2_path, index=False)

    # Correlation-family results table (tbl-3)
    tbl3 = pd.DataFrame({
        "test": names,
        "rho": [t2["rho"], t3a["rho"], t3b["flight"]["rho"],
                t3b["surface"]["rho"]],
        "n": [t2["n"], t3a["n"], t3b["flight"]["n"], t3b["surface"]["n"]],
        "p_raw": [float(x) for x in praw],
        "p_holm": [round(float(x), 4) for x in holm],
        "p_bh_fdr": [round(float(x), 4) for x in bh],
    })
    tbl3_path = os.path.join(TBL_DIR, "tbl-3-correlation-family.csv")
    tbl3.to_csv(tbl3_path, index=False)

    # ------------------------------------------------------------------
    # Figures
    # ------------------------------------------------------------------
    make_figures(df, flight, surface, t1, t2, t3b, tbl3)

    # ------------------------------------------------------------------
    # tables datapackage (Frictionless)
    # ------------------------------------------------------------------
    write_datapackage()

    # ------------------------------------------------------------------
    # results.json
    # ------------------------------------------------------------------
    results = {
        "dataset": {
            "source": "Zenodo record 4964380 (DOI 10.5061/dryad.5q04b32)",
            "file": DATA_FNAME,
            "md5_expected": EXPECTED_MD5,
            "md5_observed": observed_md5,
            "license": "CC0",
        },
        "seed": SEED,
        "profile": profile,
        "T1_catch_by_mode": t1,
        "T2_catch_vs_conspmax": t2,
        "T3a_catch_vs_conspinit": t3a,
        "T3b_catch_vs_conspmax_within_mode": t3b,
        "multiplicity_correction": correction,
        "sensitivity_perID": sensitivity,
    }
    with open(os.path.join(HERE, "results.json"), "w") as fh:
        # Undefined statistics (e.g. Spearman on a constant column) come back
        # as NaN; write them as null so the file is strict JSON. With
        # allow_nan=False, any non-finite value that slips through raises
        # instead of producing an invalid file.
        json.dump(_json_safe(results), fh, indent=2, allow_nan=False)

    print("md5 observed:", observed_md5)
    print("T1 U_flight=%.1f p=%.4g rbc=%.3f" %
          (t1["U_flight"], t1["p_raw"], t1["effect_size_rank_biserial"]))
    print("T2 rho=%.3f p=%.4g" % (t2["rho"], t2["p_raw"]))
    print("T3a rho=%.3f p=%.4g" % (t3a["rho"], t3a["p_raw"]))
    print("T3b flight rho=%.3f p=%.4g | surface rho=%.3f p=%.4g" %
          (t3b["flight"]["rho"], t3b["flight"]["p_raw"],
           t3b["surface"]["rho"], t3b["surface"]["p_raw"]))
    print("BH:", correction["p_bh_fdr"])
    print("Holm:", correction["p_holm"])
    print("Sensitivity T2 per-ID rho=%.3f p=%.4g" %
          (sensitivity["T2_perID_spearman_rho"], sensitivity["T2_perID_p"]))
    return results


# ----------------------------------------------------------------------
# Multiplicity corrections (self-contained; deterministic)
# ----------------------------------------------------------------------
def holm_correction(pvals):
    """Holm step-down adjusted p-values, returned in the input order.

    With the finite p-values sorted ascending, the i-th (0-based) is scaled
    by ``m - i``. A running maximum then enforces monotonicity, and results
    are capped at 1. ``m`` is the full family size ``len(pvals)``.

    NaN p-values (an undefined test) are returned as NaN rather than being
    given a fabricated adjusted value. Because they sort last, the finite
    results are identical to treating them as the largest p-values.

    Parameters
    ----------
    pvals : array-like of float
        Raw p-values.

    Returns
    -------
    numpy.ndarray
        Adjusted p-values, same length and order as ``pvals``.
    """
    p = np.asarray(pvals, float)
    m = len(p)
    adj = np.full(m, np.nan)
    order = np.argsort(p)  # NaNs are placed at the end
    k = int(np.count_nonzero(~np.isnan(p)))
    if k:
        ranked = order[:k]
        scaled = (m - np.arange(k)) * p[ranked]
        adj[ranked] = np.minimum(np.maximum.accumulate(scaled), 1.0)
    return adj


def benjamini_hochberg(pvals):
    """Benjamini-Hochberg (FDR) adjusted p-values, returned in the input order.

    The finite p-value of ascending rank r (1-based) is scaled by ``m / r``.
    A running minimum, taken from the largest rank downwards, enforces
    monotonicity, and results are capped at 1. ``m`` is the full family size
    ``len(pvals)``.

    NaN p-values (an undefined test) are returned as NaN rather than a
    fabricated 1.0. The finite results are identical to treating them as
    p = 1.

    Parameters
    ----------
    pvals : array-like of float
        Raw p-values.

    Returns
    -------
    numpy.ndarray
        Adjusted p-values, same length and order as ``pvals``.
    """
    p = np.asarray(pvals, float)
    m = len(p)
    adj = np.full(m, np.nan)
    order = np.argsort(p)  # NaNs are placed at the end
    k = int(np.count_nonzero(~np.isnan(p)))
    if k:
        ranked = order[:k]
        scaled = p[ranked] * m / np.arange(1, k + 1)
        step_up = np.minimum.accumulate(scaled[::-1])[::-1]
        adj[ranked] = np.minimum(step_up, 1.0)
    return adj


# ----------------------------------------------------------------------
# Figures
# ----------------------------------------------------------------------
def make_figures(df, flight, surface, t1, t2, t3b, tbl3):
    """Write fig-1..fig-3 as PNGs into FIG_DIR.

    Parameters
    ----------
    df : pandas.DataFrame
        Bout-level data; reads the ``intermode``, ``conspmax`` and ``catch``
        columns.
    flight, surface : numpy.ndarray
        Catch counts for each interaction mode (fig-1).
    t1, t2 : dict
        T1 / T2 result dicts, used for axis labels and titles.
    t3b : dict
        Currently unused; the per-mode rows for fig-3 are read from ``tbl3``.
    tbl3 : pandas.DataFrame
        Correlation-family table with rows in the order T2, T3a, T3b-flight,
        T3b-surface. Columns ``rho``, ``p_raw``, ``p_bh_fdr`` and ``n`` are
        read.

    Point jitter draws from the module-level RNG, so the figures are
    reproducible only when this function is called in the same sequence.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    # Colourblind-safe (Okabe-Ito) colours for the two modes, plus a neutral
    # grey for tests that pool both modes.
    C_FLIGHT = "#0072B2"   # blue
    C_SURFACE = "#E69F00"  # orange
    C_POOLED = "#999999"   # grey
    plt.rcParams.update({"figure.dpi": 150, "font.size": 11,
                         "axes.spines.top": False, "axes.spines.right": False})

    # -- fig-1 : catch by interaction mode --
    fig1, ax1 = plt.subplots(figsize=(5.2, 4.2))
    data = [flight, surface]
    positions = [1, 2]
    colors = [C_FLIGHT, C_SURFACE]
    bp = ax1.boxplot(data, positions=positions, widths=0.5,
                     patch_artist=True, showfliers=False,
                     medianprops=dict(color="black", linewidth=1.6))
    for patch, c in zip(bp["boxes"], colors):
        patch.set_facecolor(c)
        patch.set_alpha(0.35)
    for pos, arr, c in zip(positions, data, colors):
        jitter = (RNG.random(len(arr)) - 0.5) * 0.28
        ax1.scatter(np.full(len(arr), pos) + jitter, arr, s=32,
                    color=c, edgecolor="black", linewidth=0.5,
                    alpha=0.9, zorder=3)
    ax1.set_xticks(positions)
    ax1.set_xticklabels([f"flight\n(n={t1['n_flight']})",
                         f"surface\n(n={t1['n_surface']})"])
    ax1.set_ylabel("Prey caught per bout (count)")
    ax1.set_xlabel("Interaction mode")
    ax1.set_title("Prey-catch count by interaction mode\n"
                  f"Mann-Whitney U={t1['U_flight']:.0f}, "
                  f"p={t1['p_raw']:.3f}, r_rb={t1['effect_size_rank_biserial']:.2f}",
                  fontsize=10.5)
    fig1.tight_layout()
    fig1.savefig(os.path.join(FIG_DIR, "fig-1-catch-by-interaction-mode.png"),
                 dpi=150)
    plt.close(fig1)

    # -- fig-2 : catch vs conspmax --
    fig2, ax2 = plt.subplots(figsize=(5.4, 4.2))
    for mode, c in [("flight", C_FLIGHT), ("surface", C_SURFACE)]:
        sub = df[df["intermode"] == mode]
        jx = (RNG.random(len(sub)) - 0.5) * 0.3
        jy = (RNG.random(len(sub)) - 0.5) * 0.3
        ax2.scatter(sub["conspmax"] + jx, sub["catch"] + jy, s=36,
                    color=c, edgecolor="black", linewidth=0.5, alpha=0.85,
                    label=f"{mode} (n={len(sub)})")
    ax2.set_xlabel("Conspecific penguins present (max during bout, count)")
    ax2.set_ylabel("Prey caught per bout (count)")
    ax2.set_title("Prey catch vs. conspecific number\n"
                  f"Spearman rho={t2['rho']:.2f}, p={t2['p_raw']:.3f}, "
                  f"n={t2['n']}", fontsize=10.5)
    ax2.legend(frameon=False, fontsize=9)
    fig2.tight_layout()
    fig2.savefig(os.path.join(FIG_DIR, "fig-2-catch-vs-conspmax.png"), dpi=150)
    plt.close(fig2)

    # -- fig-3 : correlation family as rho bars, annotated with raw and
    #    BH-adjusted p --
    fig3, ax3 = plt.subplots(figsize=(6.0, 4.0))
    labels = ["catch~conspmax\n(all, T2)",
              "catch~conspinit\n(all, T3a)",
              "catch~conspmax\n(flight, T3b)",
              "catch~conspmax\n(surface, T3b)"]
    rhos = tbl3["rho"].to_numpy()
    ypos = np.arange(len(rhos))[::-1]
    # Pooled tests (T2, T3a) get the neutral colour. Only the per-mode rows
    # reuse the flight/surface colours from fig-1 and fig-2.
    bar_colors = [C_POOLED, C_POOLED, C_FLIGHT, C_SURFACE]
    ax3.barh(ypos, rhos, color=bar_colors, alpha=0.7, edgecolor="black")
    ax3.axvline(0, color="black", linewidth=0.8)
    for y, rho, praw, pbh, n in zip(ypos, rhos, tbl3["p_raw"], tbl3["p_bh_fdr"],
                                    tbl3["n"]):
        xoff = 0.02 if rho >= 0 else -0.02
        ha = "left" if rho >= 0 else "right"
        ax3.text(rho + xoff, y,
                 f"rho={rho:.2f}\np={praw:.3f}, p_BH={pbh:.3f} (n={n})",
                 va="center", ha=ha, fontsize=8)
    ax3.set_yticks(ypos)
    ax3.set_yticklabels(labels, fontsize=9)
    ax3.set_xlabel("Spearman rho (catch vs. conspecific number)")
    ax3.set_xlim(-1.05, 1.05)
    ax3.set_title("Robustness: catch–conspecific correlation across the "
                  "test family", fontsize=10.5)
    fig3.tight_layout()
    fig3.savefig(os.path.join(FIG_DIR, "fig-3-correlation-family-robustness.png"),
                 dpi=150)
    plt.close(fig3)


# ----------------------------------------------------------------------
# Frictionless datapackage for tables/
# ----------------------------------------------------------------------
def _frictionless_type(series):
    """Map a pandas column to a Frictionless Table Schema field type.

    Boolean columns map to "boolean", integer columns of any width to
    "integer", and float columns of any width to "number". Everything else
    (e.g. object or string columns) maps to "string".
    """
    if pd.api.types.is_bool_dtype(series):
        return "boolean"
    if pd.api.types.is_integer_dtype(series):
        return "integer"
    if pd.api.types.is_float_dtype(series):
        return "number"
    return "string"


def write_datapackage():
    """Write TBL_DIR/datapackage.json describing the three result tables.

    Each CSV is re-read with pandas to infer a field type per column. Every
    resource is declared as a ``tabular-data-resource``, which the
    ``tabular-data-package`` profile requires of its resources.
    """
    resources = []
    for fname in ["tbl-1-group-summary.csv", "tbl-2-per-id-sensitivity.csv",
                  "tbl-3-correlation-family.csv"]:
        table = pd.read_csv(os.path.join(TBL_DIR, fname))
        fields = [{"name": col, "type": _frictionless_type(table[col])}
                  for col in table.columns]
        resources.append({
            "name": os.path.splitext(fname)[0],
            "path": fname,
            "profile": "tabular-data-resource",
            "format": "csv",
            "mediatype": "text/csv",
            "schema": {"fields": fields},
        })
    dp = {
        "name": "seabird-foraging-facilitation-tables",
        "profile": "tabular-data-package",
        "resources": resources,
    }
    with open(os.path.join(TBL_DIR, "datapackage.json"), "w") as fh:
        json.dump(dp, fh, indent=2)


# Script entry point. main() writes all outputs and prints a summary; the
# results dict it returns is not used here.
if __name__ == "__main__":
    main()
