#!/usr/bin/env python3
"""
Pre-registered secondary analysis of the InvaCost database (economic costs of
biological invasions): how does the magnitude of recorded annual cost differ by
(T1) whether the cost is observed vs potential, (T2) environment, and (T3)
damage vs management cost type?

Automated executor (arm B) — ingested via the direct Figshare connector. The
Figshare file bytes are S3-served (blocked in sandboxed agents), so this runs
locally, md5-verifying the pinned file.

Dataset: Figshare 10.6084/m9.figshare.12668570 (CC-BY-4.0),
         "InvaCost: Economic cost estimates associated with biological invasions",
         file "InvaCost_database_v4.1.xlsx" (md5 1cdc52dd50753fa30c2f3e12fc61fc44).
Response: Cost_estimate_per_year_2017_USD_exchange_rate (the standardised annual
          cost recommended for magnitude comparisons), analysed on the log10 scale.

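Pre-registered tests (two-sided Mann-Whitney U on log10 cost — rank-based, so
equivalent to testing raw cost — with rank-biserial effect sizes and bootstrap
95% CIs on the median-cost ratio; p-values adjusted across the three-test
family with both Benjamini-Hochberg and Holm):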
  T1  Observed vs Potential (Implementation)
  T2  Terrestrial vs Aquatic (Environment)
  T3  Damage vs Management (Type_of_cost_merged)
"""
import hashlib
import io
import json
import os
import urllib.request

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats

# Seed NumPy's global RNG; the jittered scatter points in the box plots draw
# from it, so the figures are reproducible run to run.
np.random.seed(0)

# Pinned input: the md5 of InvaCost_database_v4.1.xlsx and its direct download URL.
EXPECTED_MD5 = "1cdc52dd50753fa30c2f3e12fc61fc44"
URL = "https://ndownloader.figshare.com/files/45061979"  # InvaCost_database_v4.1.xlsx

# Response column analysed throughout (standardised annual cost, 2017 USD).
COST = "Cost_estimate_per_year_2017_USD_exchange_rate"

# All outputs live next to this script: results.json in OUT, plus figures/ and tables/.
OUT = os.path.dirname(os.path.abspath(__file__))
FIG = os.path.join(OUT, "figures")
TAB = os.path.join(OUT, "tables")
os.makedirs(FIG, exist_ok=True)
os.makedirs(TAB, exist_ok=True)

# Colour-blind-safe pair (Okabe-Ito blue and orange) used for the two groups in each plot.
CB = ["#0072B2", "#E69F00"]


def fetch():
    """Download the pinned InvaCost v4.1 workbook and verify its md5.

    The download URL is resolved by file name through the Figshare article
    API, so a re-issued file id does not break the download. If the article
    listing no longer contains the file, the pinned ``URL`` is used instead.
    Either way, the bytes must hash to ``EXPECTED_MD5``, so a fallback can
    never silently substitute different data.

    Returns
    -------
    tuple[bytes, str]
        The raw workbook bytes and their md5 hex digest.

    Raises
    ------
    SystemExit
        If the downloaded bytes do not match ``EXPECTED_MD5``.
    urllib.error.URLError
        On network / HTTP failures (propagated from ``urlopen``).
    """
    fname = "InvaCost_database_v4.1.xlsx"
    # Close both HTTP responses deterministically instead of leaking them.
    with urllib.request.urlopen("https://api.figshare.com/v2/articles/12668570", timeout=60) as resp:
        article = json.load(resp)
    matches = [x["download_url"] for x in article.get("files", []) if x.get("name") == fname]
    url = matches[0] if matches else URL
    with urllib.request.urlopen(url, timeout=180) as resp:
        raw = resp.read()
    md5 = hashlib.md5(raw).hexdigest()
    if md5 != EXPECTED_MD5:
        raise SystemExit(f"md5 mismatch: {md5}")
    return raw, md5


def rank_biserial(u, n1, n2):
    """Return the rank-biserial correlation ``1 - 2U / (n1 * n2)``.

    When *u* is the U statistic of the first sample (what
    ``scipy.stats.mannwhitneyu`` returns), this equals
    P(sample2 > sample1) - P(sample1 > sample2), with ties counted half.
    A positive value therefore means sample 2 tends to be larger. The result
    lies in [-1, 1].
    """
    n_pairs = n1 * n2
    return 1.0 - 2.0 * u / n_pairs


def boot_ratio_ci(a, b, n=3000, seed=0):
    """Percentile-bootstrap 95% interval for ``median(a) / median(b)`` (raw cost).

    Each of the *n* replicates resamples *a*, then *b*, with replacement and
    at their original sizes. All draws come from one generator seeded with
    *seed*, so the interval is reproducible.

    Returns
    -------
    list[float]
        ``[low, high]`` at the 2.5th and 97.5th percentiles of the replicate ratios.
    """
    gen = np.random.default_rng(seed)
    size_a, size_b = len(a), len(b)
    ratios = []
    for _ in range(n):
        # Draw a's resample before b's to keep the generator stream order fixed.
        med_a = np.median(gen.choice(a, size_a))
        med_b = np.median(gen.choice(b, size_b))
        ratios.append(med_a / med_b)
    low = float(np.percentile(ratios, 2.5))
    high = float(np.percentile(ratios, 97.5))
    return [low, high]


def bh_holm(p):
    """Adjust a family of p-values with Benjamini-Hochberg and Holm.

    Parameters
    ----------
    p : sequence of float
        Raw p-values, one per test in the family.

    Returns
    -------
    tuple[list[float], list[float]]
        ``(bh, holm)``: adjusted p-values in the same order as *p*, each
        capped at 1.
        - BH is the step-up running minimum of ``p * m / rank``.
        - Holm is the step-down running maximum of ``(m - rank + 1) * p``.
    """
    pvals = np.asarray(p, float)
    m = len(pvals)
    order = np.argsort(pvals)

    # Benjamini-Hochberg: walk from the largest p (rank m) down to rank 1.
    bh = np.empty(m)
    running_min = 1.0
    for rank in range(m, 0, -1):
        idx = order[rank - 1]
        running_min = min(running_min, pvals[idx] * m / rank)
        bh[idx] = min(running_min, 1.0)

    # Holm: walk from the smallest p (rank 1) upward.
    holm = np.empty(m)
    running_max = 0.0
    for rank in range(1, m + 1):
        idx = order[rank - 1]
        running_max = max(running_max, (m - rank + 1) * pvals[idx])
        holm[idx] = min(running_max, 1.0)

    return bh.tolist(), holm.tolist()


def compare(df, col, g1, g2, label):
    """Compare annual cost between two levels of a categorical column.

    Parameters
    ----------
    df : pandas.DataFrame
        Cost records containing *col* and the ``COST`` column.
    col : str
        Grouping column, e.g. ``"Implementation"``.
    g1, g2 : str
        The two levels of *col* to compare; *g1* is the numerator group.
    label : str
        Human-readable test name, stored under ``"test"``.

    Returns
    -------
    dict
        A JSON-serialisable summary containing:
        - group sizes;
        - the two-sided Mann-Whitney U and p-value on log10 cost;
        - the rank-biserial r;
        - raw-USD group medians;
        - the median ratio g1/g2 and its bootstrap 95% CI.

    Raises
    ------
    ValueError
        If either group is empty or has a cost that is not finite and > 0
        (log10 would be undefined).

    Notes
    -----
    ``U`` is scipy's statistic for the *g1* sample. ``rank_biserial`` computes
    ``1 - 2U/(n1*n2)``, so ``rank_biserial_r`` is positive when *g2* costs
    tend to be larger. That is the opposite orientation to a
    ``median_ratio_1_over_2`` above 1.
    """
    a = df.loc[df[col] == g1, COST].to_numpy(float)
    b = df.loc[df[col] == g2, COST].to_numpy(float)
    # Fail loudly, naming the offending group, rather than via an opaque scipy
    # error (empty group) or silent -inf/NaN logs (cost <= 0 or missing).
    for level, x in ((g1, a), (g2, b)):
        if x.size == 0:
            raise ValueError(f"{label}: no records with {col} == {level!r}")
        if not (np.isfinite(x) & (x > 0)).all():
            raise ValueError(f"{label}: {col} == {level!r} has costs that are not finite and > 0")
    la, lb = np.log10(a), np.log10(b)
    u, p = stats.mannwhitneyu(la, lb, alternative="two-sided")
    med_a, med_b = np.median(a), np.median(b)
    return {
        "test": label, "group1": g1, "group2": g2, "n1": int(len(a)), "n2": int(len(b)),
        "U": float(u), "p": float(p), "rank_biserial_r": float(rank_biserial(u, len(a), len(b))),
        "median1_usd": float(med_a), "median2_usd": float(med_b),
        "median_ratio_1_over_2": float(med_a / med_b),
        "median_ratio_ci95": boot_ratio_ci(a, b),
    }


def _load_costs(raw):
    """Parse the workbook bytes and keep rows whose annual cost is a number > 0.

    The cost column is coerced to float exactly once. Unparseable entries
    become NaN and are dropped together with zero and negative costs, because
    the analysis works on log10 cost.
    """
    df = pd.read_excel(io.BytesIO(raw), sheet_name=0)
    df[COST] = pd.to_numeric(df[COST], errors="coerce").astype(float)
    return df[df[COST] > 0].copy()


def _box(df2, col, groups, title, fname, res):
    """Save a box plot (plus jittered points) of log10 cost per group to FIG/fname.

    Jitter draws from NumPy's global RNG, so figures depend on its seed and
    on the order in which plots are made.
    """
    fig, ax = plt.subplots(figsize=(5.2, 4.2), dpi=150)
    try:
        data = [np.log10(df2.loc[df2[col] == g, COST].to_numpy(float)) for g in groups]
        ax.boxplot(data, tick_labels=[f"{g}\n(n={len(x):,})" for g, x in zip(groups, data)], showfliers=False)
        for i, x in enumerate(data):
            xs = np.full(len(x), i + 1) + np.random.uniform(-.06, .06, len(x))
            ax.scatter(xs, x, s=3, alpha=0.15, color=CB[i % 2])
        ax.set_ylabel("log10 annual cost (2017 USD)")
        ax.set_title(f"{title}\nMann-Whitney p={res['p']:.1e}, median ratio {res['median_ratio_1_over_2']:.1f}x")
        fig.tight_layout()
        fig.savefig(os.path.join(FIG, fname))
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def _write_tables(df, results, keys, fam):
    """Write the three summary CSVs into TAB and return their file names.

    - tbl-1 summarises each comparison named in *keys*.
    - tbl-2 gives count/median/mean raw cost per ``Implementation`` level over
      all of *df*.
    - tbl-3 lists raw, BH and Holm p-values for the family *fam*.
    """
    rows = [results[k] for k in keys]
    pd.DataFrame([{"test": r["test"], "n1": r["n1"], "n2": r["n2"], "median1_usd": r["median1_usd"],
                   "median2_usd": r["median2_usd"], "median_ratio": r["median_ratio_1_over_2"],
                   "ratio_ci_low": r["median_ratio_ci95"][0], "ratio_ci_high": r["median_ratio_ci95"][1],
                   "rank_biserial_r": r["rank_biserial_r"], "p": r["p"]} for r in rows]).to_csv(os.path.join(TAB, "tbl-1-group-comparisons.csv"), index=False)
    df.groupby("Implementation")[COST].agg(["count", "median", "mean"]).to_csv(os.path.join(TAB, "tbl-2-cost-by-implementation.csv"))
    pd.DataFrame([{"test": n, **results["multiplicity"][n]} for n, _ in fam]).to_csv(os.path.join(TAB, "tbl-3-multiplicity.csv"), index=False)
    return ["tbl-1-group-comparisons.csv", "tbl-2-cost-by-implementation.csv", "tbl-3-multiplicity.csv"]


def _write_datapackage(paths):
    """Write TAB/datapackage.json describing each CSV in *paths*.

    Columns are read from each CSV header. ``test`` and ``Implementation``
    are typed as strings; every other column is typed as a number.
    """
    resources = []
    for path in paths:
        # Only the header row is needed to list the columns.
        cols = pd.read_csv(os.path.join(TAB, path), nrows=0).columns
        fields = [{"name": c, "type": "string" if c in ("test", "Implementation") else "number"} for c in cols]
        resources.append({"name": path[:-4], "path": path, "schema": {"fields": fields}})
    with open(os.path.join(TAB, "datapackage.json"), "w") as fh:
        json.dump({"name": "invacost-tables", "resources": resources}, fh, indent=2)


def main():
    """Run the pre-registered T1-T3 comparisons and write every output.

    Steps:
    - Fetch and md5-verify the workbook.
    - Keep records with a positive numeric annual cost.
    - Run the three Mann-Whitney comparisons and adjust their p-values as
      one family (BH and Holm).

    Writes:
    - results.json into OUT;
    - three box-plot figures into FIG;
    - three CSV tables plus datapackage.json into TAB.

    Finally prints a short summary, including both adjusted p-value sets.
    """
    raw, md5 = fetch()
    df = _load_costs(raw)

    results = {"dataset": {"figshare_id": "12668570", "md5": md5, "n_cost_records": int(len(df)),
                           "response": COST, "scale": "log10", "note": "recorded standardised costs; InvaCost has known reporting/geographic biases"}}
    results["T1_observed_vs_potential"] = compare(df, "Implementation", "Observed", "Potential", "Observed vs Potential")
    env = df[df["Environment"].isin(["Terrestrial", "Aquatic"])]
    results["T2_terrestrial_vs_aquatic"] = compare(env, "Environment", "Terrestrial", "Aquatic", "Terrestrial vs Aquatic")
    ct = df[df["Type_of_cost_merged"].isin(["Damage", "Management"])]
    results["T3_damage_vs_management"] = compare(ct, "Type_of_cost_merged", "Damage", "Management", "Damage vs Management")
    keys = ("T1_observed_vs_potential", "T2_terrestrial_vs_aquatic", "T3_damage_vs_management")

    # T1-T3 are one pre-registered family, so their p-values are adjusted together.
    fam = [("T1", results["T1_observed_vs_potential"]["p"]), ("T2", results["T2_terrestrial_vs_aquatic"]["p"]),
           ("T3", results["T3_damage_vs_management"]["p"])]
    bh, holm = bh_holm([p for _, p in fam])
    results["multiplicity"] = {n: {"raw_p": p, "bh_p": bh[i], "holm_p": holm[i]} for i, (n, p) in enumerate(fam)}
    with open(os.path.join(OUT, "results.json"), "w") as fh:
        json.dump(results, fh, indent=2)

    _box(df, "Implementation", ["Observed", "Potential"], "Recorded cost: observed vs potential", "fig-1-observed-vs-potential.png", results["T1_observed_vs_potential"])
    _box(env, "Environment", ["Terrestrial", "Aquatic"], "Recorded cost by environment", "fig-2-terrestrial-vs-aquatic.png", results["T2_terrestrial_vs_aquatic"])
    _box(ct, "Type_of_cost_merged", ["Damage", "Management"], "Recorded cost: damage vs management", "fig-3-damage-vs-management.png", results["T3_damage_vs_management"])

    _write_datapackage(_write_tables(df, results, keys, fam))

    print("md5", md5, "| n", len(df))
    for k in keys:
        r = results[k]
        print(k, "p=%.2e ratio=%.1fx CI%s r=%.2f n=%d/%d" % (r["p"], r["median_ratio_1_over_2"], r["median_ratio_ci95"], r["rank_biserial_r"], r["n1"], r["n2"]))
    print("BH", [round(x, 4) for x in bh])
    print("Holm", [round(x, 4) for x in holm])


# Run the full analysis only when executed as a script; importing this module
# defines the helpers (and creates the output dirs) but does not call main().
if __name__ == "__main__":
    main()
