"""
Arm B, round 2 (open question), automated executor.

Self-framed question: how functionally divergent are the single-cell Bodo
genomes from each other and from the B. saltans reference at the protein-domain
(PFAM) level, and do the single cells form a group distinct from the reference?

Source: Figshare 10.6084/m9.figshare.31362613 (CC-BY), Supplementary File 9,
sheet 'PFAMs counts' (PFAM domain counts for Bsal reference + 7 single cells).
Deterministic; classical MDS implemented in numpy (no sklearn dependency).
"""
import hashlib, json, os, urllib.request
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.stats import mannwhitneyu

HERE = os.path.dirname(os.path.abspath(__file__))
# Shared raw-data directory one level up (reused rather than per-round).
RAW = os.path.join(os.path.dirname(HERE), "raw")
# Output locations for this round's figures and tables.
FIG = os.path.join(HERE, "figures")
TAB = os.path.join(HERE, "tables")
for d in (RAW, FIG, TAB):
    os.makedirs(d, exist_ok=True)

# Figshare file id and expected MD5 of Supplementary File 9.xlsx.
FID = "61984570"
MD5 = "660a3153c8503ab9a7dfbe280781d532"

def fetch():
    """Return the local path to Supplementary File 9, downloading it if absent.

    The workbook is cached as ``RAW/file9.xlsx``. Its MD5 is checked against
    ``MD5`` on every call, so a stale or corrupted cache is never analysed.

    Raises:
        ValueError: if the cached file's MD5 digest does not equal ``MD5``.
    """
    path = os.path.join(RAW, "file9.xlsx")
    if not os.path.exists(path):
        # Download under a temporary name, then atomically move it into place.
        # An interrupted transfer then cannot leave a truncated file at `path`
        # that every later run would pick up (and fail the checksum on).
        tmp = path + ".part"
        urllib.request.urlretrieve(f"https://ndownloader.figshare.com/files/{FID}", tmp)
        os.replace(tmp, path)
    with open(path, "rb") as fh:
        digest = hashlib.md5(fh.read()).hexdigest()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if digest != MD5:
        raise ValueError(f"md5 mismatch for {path}: expected {MD5}, got {digest}")
    return path

def classical_mds(D, k=2):
    """Classical (Torgerson) multidimensional scaling, i.e. PCoA.

    Parameters
    ----------
    D : array_like, shape (n, n)
        Symmetric matrix of pairwise distances.
    k : int, default 2
        Number of leading axes (largest eigenvalues) to keep.

    Returns
    -------
    ndarray, shape (n, min(k, n))
        Point coordinates. Each axis is an eigenvector of the double-centred
        matrix scaled by the square root of its eigenvalue; negative
        eigenvalues are clipped to 0. Each axis is sign-normalised so that its
        largest-magnitude coordinate is positive, which makes the output
        independent of the eigensolver's arbitrary sign choice.
    """
    D = np.asarray(D, dtype=float)
    n = D.shape[0]
    # Double-centre the squared distances: B = -1/2 * J D^2 J.
    J = np.eye(n) - np.ones((n, n)) / n
    B = -0.5 * J @ (D ** 2) @ J
    w, V = np.linalg.eigh(B)
    order = np.argsort(w)[::-1][:k]          # largest eigenvalues first
    w, V = w[order], V[:, order]
    # Eigenvectors are only defined up to sign (the choice depends on the
    # LAPACK build), so fix it deterministically: largest-|entry| positive.
    if V.size:
        lead = V[np.argmax(np.abs(V), axis=0), np.arange(V.shape[1])]
        V = V * np.where(lead < 0, -1.0, 1.0)
    # Negative eigenvalues (non-Euclidean distances) give zero-length axes.
    return V * np.sqrt(np.clip(w, 0, None))

def _cell_vs_reference(Dm, ref_i, cell_i):
    """Compare cell-to-reference distances with cell-to-cell distances.

    Parameters
    ----------
    Dm : ndarray, shape (n, n)
        Square distance matrix over all genomes.
    ref_i : int
        Row/column index of the reference genome.
    cell_i : list of int
        Indices of the single-cell genomes.

    Returns
    -------
    within : list of float
        Distance for every unordered pair of single cells.
    to_ref : list of float
        Distance from each single cell to the reference.
    U, p : float
        One-sided Mann-Whitney U statistic and p-value for the alternative
        "to_ref > within".
    """
    within = [Dm[i, j] for a, i in enumerate(cell_i) for j in cell_i[a + 1:]]
    to_ref = [Dm[ref_i, i] for i in cell_i]
    # NOTE(review): pairwise distances that share a genome are not independent
    # observations, so this p-value is approximate (likely anti-conservative);
    # a label-permutation test would be the stricter alternative.
    U, p = mannwhitneyu(to_ref, within, alternative="greater")
    return within, to_ref, U, p


def main():
    """Run the PFAM-profile divergence analysis and write all outputs.

    Steps:
      1. Load the per-genome PFAM counts and convert them to relative
         abundances.
      2. Compute Bray-Curtis distances (primary) and cosine distances
         (robustness check).
      3. Test whether single cells are farther from the Bsal reference than
         from each other.
      4. Write tbl-1 (distance matrix), fig-1 (dendrogram), fig-2 (PCoA) and
         results2.json, and echo the results to stdout.

    Raises:
        ValueError: if a genome column has no PFAM counts (its relative
            abundance would be undefined), or if there is no "Bsal" column.
    """
    df = pd.read_excel(fetch(), sheet_name="PFAMs counts")
    prefix = "Freq_"
    freq_cols = [c for c in df.columns if str(c).startswith(prefix)]
    # Strip only the leading prefix (str.replace would also hit later matches).
    labels = [str(c)[len(prefix):] for c in freq_cols]
    counts = df[freq_cols].apply(pd.to_numeric, errors="coerce").fillna(0).to_numpy(float)  # PFAMs x genomes
    totals = counts.sum(axis=0, keepdims=True)
    empty = [lab for lab, t in zip(labels, totals[0]) if t == 0]
    if empty:
        # Dividing by a zero total would yield NaN profiles and NaN distances.
        raise ValueError(f"genomes with no PFAM counts: {empty}")
    rel = counts / totals                                    # normalise per genome (control for size)
    X = rel.T                                                # genomes x PFAMs

    # Condensed Bray-Curtis vector, reused below for the dendrogram linkage.
    bc_condensed = pdist(X, metric="braycurtis")
    D = squareform(bc_condensed)
    n = len(labels)
    bsal_i = labels.index("Bsal")
    cell_i = [i for i in range(n) if i != bsal_i]

    within_cells, cell_to_bsal, U, p = _cell_vs_reference(D, bsal_i, cell_i)

    mean_pair = float(np.mean(D[np.triu_indices(n, 1)]))
    mean_to_others = {labels[i]: float(np.mean([D[i, j] for j in range(n) if j != i])) for i in range(n)}
    most_divergent = max(mean_to_others, key=mean_to_others.get)

    # robustness: same comparison under cosine distance
    Dc = squareform(pdist(X, metric="cosine"))
    _, _, _, p_cos = _cell_vs_reference(Dc, bsal_i, cell_i)

    # tbl-1: distance matrix
    dm = pd.DataFrame(D, index=labels, columns=labels).round(4)
    dm.to_csv(os.path.join(TAB, "tbl-1-braycurtis-distance.csv"))

    # fig-1: dendrogram (average linkage on Bray-Curtis)
    plt.figure(figsize=(6, 4))
    Z = linkage(bc_condensed, method="average")
    dendrogram(Z, labels=labels, leaf_rotation=0, color_threshold=0)
    plt.ylabel("Bray-Curtis distance (PFAM profiles)")
    plt.title("Functional clustering of Bodo genomes")
    plt.tight_layout()
    plt.savefig(os.path.join(FIG, "fig-1-dendrogram.png"), dpi=150)
    plt.close()

    # fig-2: PCoA ordination
    coords = classical_mds(D, 2)
    plt.figure(figsize=(5.6, 4.6))
    for i, lab in enumerate(labels):
        col = "#D85A30" if lab == "Bsal" else "#534AB7"
        plt.scatter(coords[i, 0], coords[i, 1], s=70, color=col,
                    edgecolor="white", zorder=3)
        plt.annotate(lab, (coords[i, 0], coords[i, 1]), fontsize=9,
                     xytext=(5, 4), textcoords="offset points")
    plt.xlabel("PCoA 1")
    plt.ylabel("PCoA 2")
    plt.title("PFAM-profile ordination (orange = B. saltans reference)")
    plt.tight_layout()
    plt.savefig(os.path.join(FIG, "fig-2-pcoa.png"), dpi=150)
    plt.close()

    results = {
        "dataset_doi": "10.6084/m9.figshare.31362613",
        "question": "functional (PFAM) divergence among single-cell Bodo genomes vs the B. saltans reference",
        "n_genomes": n, "n_pfams": int(counts.shape[0]), "labels": labels,
        "mean_pairwise_braycurtis": round(mean_pair, 4),
        "mean_within_single_cells": round(float(np.mean(within_cells)), 4),
        "mean_single_cell_to_reference": round(float(np.mean(cell_to_bsal)), 4),
        "mannwhitney_cell_to_ref_greater_than_within": {"U": float(U), "p_value": float(p)},
        "robustness_cosine_p": float(p_cos),
        "most_divergent_genome": most_divergent,
        "mean_distance_to_others": {k: round(v, 4) for k, v in mean_to_others.items()},
        "figures": ["figures/fig-1-dendrogram.png", "figures/fig-2-pcoa.png"],
        "table": "tables/tbl-1-braycurtis-distance.csv",
    }
    # Close (and flush) the JSON file deterministically.
    with open(os.path.join(HERE, "results2.json"), "w") as fh:
        json.dump(results, fh, indent=2)
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()
