#!/usr/bin/env python3
"""
Pre-registered secondary analysis: preferred-gender congruence and facial attractiveness ratings.

Reads the two attached CSVs (verifying their md5 checksums), reshapes them to long form, defines
preference congruence, runs exactly the three pre-registered tests (T1, T2, T3; five hypotheses in
total), applies a Holm-Bonferroni correction across that five-hypothesis family, and writes every
figure, table, and results.json.

Self-contained and deterministic. Re-running reproduces every number.

Data: Face Research Lab London Set (DeBruine & Jones), figshare 10.6084/m9.figshare.5047666, CC-BY-4.0.

NOTE ON RUNTIME: the crossed linear mixed models are fit on ~225k-256k observations with
~2,600 crossed random-effect levels. Each fit takes several minutes and up to ~40 GB RAM;
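the full script runs on the order of 1-2 hours on a multi-core machine. A two-way (rater + face)
within fixed-effects cross-check (fast, optimiser-independent) benchmarks the congruence estimates;
rater-level main effects (sex, age) are absorbed by the rater means and are therefore not checked.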
"""

import os
import json
import hashlib
import warnings

import numpy as np
import pandas as pd
from scipy import stats
import statsmodels.formula.api as smf
from statsmodels.stats.multitest import multipletests

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Silence every warning category for the whole run (keeps console output clean during the
# multi-hour fits). NOTE(review): this also hides any convergence warning the statsmodels
# optimiser emits, so convergence is only visible if checked explicitly on the fit results.
warnings.filterwarnings("ignore")

# --------------------------------------------------------------------------------------
# 0. Configuration and reproducibility
# --------------------------------------------------------------------------------------
SEED = 42  # global NumPy seed, set once at import so any random draw is reproducible
np.random.seed(SEED)

# Input CSVs and the md5 digests they are required to match before any analysis runs.
RATINGS_CSV, INFO_CSV = "london_faces_ratings.csv", "london_faces_info.csv"
MD5_RATINGS, MD5_INFO = (
    "fe3743a1bb6b2b414e5170b8723aec1b",
    "52e9d142812130f0b19af168260fb160",
)

# Figures and tables are written next to this script; the folders are created at import time.
HERE = os.path.dirname(os.path.abspath(__file__))
FIG_DIR, TBL_DIR = (os.path.join(HERE, sub) for sub in ("figures", "tables"))
os.makedirs(FIG_DIR, exist_ok=True)
os.makedirs(TBL_DIR, exist_ok=True)

# Colourblind-safe palette (Wong 2011 / Okabe-Ito)
BLUE = "#0072B2"
ORANGE = "#E69F00"
GREY = "#7A7A7A"

# Compact publication styling shared by every figure.
plt.rcParams.update(
    {
        "font.family": "sans-serif",
        "font.size": 8,
        "figure.dpi": 150,
        "savefig.dpi": 200,
        "axes.titlesize": 9,
        "axes.labelsize": 8,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "xtick.labelsize": 7,
        "ytick.labelsize": 7,
    }
)


def md5_of(path):
    """Return the lowercase hexadecimal MD5 digest of the file at ``path``.

    The file is opened in binary mode and fed to the hash in 1 MiB chunks, so the digest
    is identical to hashing the whole byte content at once.
    """
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()


# --------------------------------------------------------------------------------------
# 1. Load, verify integrity, profile
# --------------------------------------------------------------------------------------
def load_and_verify():
    """Load both input CSVs after checking their md5 digests, and build a descriptive profile.

    Each CSV is looked up first as given (i.e. relative to the current working directory)
    and, if it is not there, next to this script (``HERE``), the directory the outputs
    are written to.

    Returns
    -------
    ratings : pandas.DataFrame
        Wide ratings table, one row per rater.
    info : pandas.DataFrame
        Face metadata table, one row per face.
    profile : dict
        md5 digests (keyed by the configured file names), rater/face counts, the number of
        ``X``-prefixed face columns, category counts for rater sex, rater preference and face
        gender (NaN included), and the mean rater age.

    Raises
    ------
    ValueError
        If either file's md5 digest differs from the registered value. This is an explicit
        raise rather than ``assert`` so the integrity check still runs under ``python -O``.
    FileNotFoundError
        If a CSV exists in neither location.
    """
    def _locate(name):
        # Prefer the path as given (CWD-relative); fall back to the script directory.
        if os.path.exists(name):
            return name
        beside_script = os.path.join(HERE, name)
        return beside_script if os.path.exists(beside_script) else name

    ratings_path, info_path = _locate(RATINGS_CSV), _locate(INFO_CSV)
    m_r, m_i = md5_of(ratings_path), md5_of(info_path)
    if m_r != MD5_RATINGS:
        raise ValueError(f"ratings md5 mismatch: {m_r} != {MD5_RATINGS}")
    if m_i != MD5_INFO:
        raise ValueError(f"info md5 mismatch: {m_i} != {MD5_INFO}")
    ratings = pd.read_csv(ratings_path)
    info = pd.read_csv(info_path)
    # profile
    profile = {
        "md5": {RATINGS_CSV: m_r, INFO_CSV: m_i},
        "n_raters_total": int(len(ratings)),
        "n_faces": int(len(info)),
        "n_face_columns": int(sum(c.startswith("X") for c in ratings.columns)),
        "rater_sex": ratings["rater_sex"].value_counts(dropna=False).to_dict(),
        "rater_sexpref": ratings["rater_sexpref"].value_counts(dropna=False).to_dict(),
        "rater_age_mean": float(ratings["rater_age"].mean()),
        "face_gender": info["face_gender"].value_counts(dropna=False).to_dict(),
    }
    return ratings, info, profile


# --------------------------------------------------------------------------------------
# 2. Reshape to long form and join face gender
# --------------------------------------------------------------------------------------
def to_long(ratings, info):
    """Reshape the wide ratings table to one row per non-missing rating and attach face data.

    Every ``X``-prefixed column of ``ratings`` holds the ratings of one face; the integer
    after the ``X`` is that face's ``face_id``. The positional row index of ``ratings``
    becomes ``rater_id``.

    Parameters
    ----------
    ratings : pandas.DataFrame
        Wide table with ``rater_sex``, ``rater_sexpref``, ``rater_age`` and ``X<face_id>`` columns.
    info : pandas.DataFrame
        Face table with ``face_id``, ``face_gender``, ``face_age`` and ``face_eth``.

    Returns
    -------
    pandas.DataFrame
        Long table: rater columns, ``face_col``, ``rating``, ``face_id`` and the face columns.

    Raises
    ------
    ValueError
        If a rating maps to no known face gender, if the rated faces differ from the faces
        in ``info``, or (``pandas.errors.MergeError``, a ValueError subclass) if ``info``
        lists a ``face_id`` more than once, which would silently duplicate ratings.
    """
    face_cols = [c for c in ratings.columns if c.startswith("X")]
    wide = ratings.reset_index().rename(columns={"index": "rater_id"})
    long = wide.melt(
        id_vars=["rater_id", "rater_sex", "rater_sexpref", "rater_age"],
        value_vars=face_cols, var_name="face_col", value_name="rating",
    )
    # The integer in X0NN IS the face_id.
    long["face_id"] = long["face_col"].str.lstrip("X").astype(int)
    long = long.dropna(subset=["rating"]).copy()
    long = long.merge(
        info[["face_id", "face_gender", "face_age", "face_eth"]],
        on="face_id", how="left", validate="many_to_one",
    )
    # Explicit raises (not assert) so these checks survive `python -O`.
    if long["face_gender"].isna().any():
        raise ValueError("every rating must map to a known face")
    if set(long["face_id"]) != set(info["face_id"]):
        raise ValueError("face id coverage mismatch")
    return long


# --------------------------------------------------------------------------------------
# 3. Congruence mapping (EXPLICIT)
# --------------------------------------------------------------------------------------
# Stated preference -> set of preferred face genders:
#   men     -> {male}          (prefers men)
#   women   -> {female}        (prefers women)
#   either  -> {male, female}  (bisexual/either: EVERY face is a preferred gender -> all congruent)
#   neither -> {}              (asexual/neither: NO face is a preferred gender -> all incongruent)
#   NaN     -> undefined       (preference not recorded -> EXCLUDED from congruence analyses)
# A rating is CONGRUENT (1) if the face's gender is in the rater's preferred set, else INCONGRUENT (0).
PREF_MAP = {"men": {"male"}, "women": {"female"},
            "either": {"male", "female"}, "neither": set()}


def add_congruence(long):
    """Return a copy of ``long`` with two float congruence indicators added.

    ``congruent``
        1.0 if ``face_gender`` is in the rater's stated-preference set (``PREF_MAP``),
        0.0 if it is not, NaN when ``rater_sexpref`` is missing or not a ``PREF_MAP`` key.
    ``congruent_sex``
        Sex-based expected preference (heterosexual assumption) for T3b: 1.0 for a male
        rater rating a female face or a female rater rating a male face, 0.0 for any other
        face gender, NaN for any other ``rater_sex`` (intersex / missing).

    Vectorised masks replace a row-wise ``apply`` (two Python calls per rating); the
    resulting values are the same.

    Parameters
    ----------
    long : pandas.DataFrame
        Needs ``rater_sexpref``, ``rater_sex`` and ``face_gender``. Not modified.
    """
    long = long.copy()
    pref = long["rater_sexpref"]
    face = long["face_gender"]

    # Start undefined (NaN) and fill only rows whose preference is a known PREF_MAP key.
    congruent = np.full(len(long), np.nan)
    for label, preferred in PREF_MAP.items():
        rows = (pref == label).to_numpy()
        congruent[rows] = face[rows].isin(list(preferred)).to_numpy(dtype=float)
    long["congruent"] = congruent

    # Heterosexual assumption: male rater -> female faces congruent; female rater -> male
    # faces congruent; any other rater_sex stays undefined (excluded from T3b).
    sex = long["rater_sex"]
    long["congruent_sex"] = np.select(
        [(sex == "male").to_numpy(), (sex == "female").to_numpy()],
        [(face == "female").to_numpy(dtype=float), (face == "male").to_numpy(dtype=float)],
        default=np.nan,
    )
    return long


# --------------------------------------------------------------------------------------
# 4. Crossed linear mixed model helper (crossed random intercepts for rater and face)
# --------------------------------------------------------------------------------------
def fit_crossed(df, formula):
    """Fit ``<formula> + (1|rater) + (1|face_id)`` as a crossed linear mixed model.

    All observations are placed in a single dummy group (``grp``) and the rater and face
    intercepts are expressed as variance components (``vc_formula``) inside that group.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``rater_id``, ``face_id`` and every variable used in ``formula``.
        Not modified.
    formula : str
        Fixed-effects formula, e.g. ``"rating ~ congruent"``.

    Returns
    -------
    Fitted statsmodels MixedLM results (L-BFGS, at most 200 iterations). If the fit
    reports non-convergence, a message naming the formula is written to stderr. This is
    needed because the module silences all warnings, which would otherwise hide the
    problem.
    """
    import sys

    df = df.reset_index(drop=True).copy()
    df["grp"] = 1
    df["rater_f"] = df["rater_id"].astype(str)
    df["face_f"] = df["face_id"].astype(str)
    vc = {"rater": "0 + C(rater_f)", "face": "0 + C(face_f)"}
    model = smf.mixedlm(formula, df, groups="grp", vc_formula=vc)
    res = model.fit(method="lbfgs", maxiter=200)
    if not getattr(res, "converged", True):
        print("WARNING: crossed LMM %r did not converge (n_obs=%d); estimates may be unreliable."
              % (formula, len(df)), file=sys.stderr)
    return res


def coef_record(res, name):
    """Extract the estimate, standard error, 95% CI and p-value of one fixed-effect term.

    Parameters
    ----------
    res : fitted results object
        Exposes name-indexed ``params``, ``bse`` and ``pvalues`` Series and a ``conf_int()``
        DataFrame with columns ``0`` (lower) and ``1`` (upper).
    name : str
        Exact term name as it appears in ``res.params``.

    Returns
    -------
    dict
        Float values for ``coef``, ``se``, ``ci_low``, ``ci_high`` and ``p``.

    Raises
    ------
    KeyError
        If ``name`` is not a fitted term. The message lists the available terms. Failing
        here replaces the old ``None`` fallbacks, which could only surface later as an
        obscure error inside the multiplicity correction.
    """
    if name not in res.params.index:
        raise KeyError("term %r not in model; available terms: %s"
                       % (name, list(res.params.index)))
    ci = res.conf_int()
    return {
        "coef": float(res.params[name]),
        "se": float(res.bse[name]),
        "ci_low": float(ci.loc[name, 0]), "ci_high": float(ci.loc[name, 1]),
        "p": float(res.pvalues[name]),
    }


def within_fe_slope(df, xcols, ycol="rating", n_iter=60, tol=1e-12):
    """Two-way (rater + face) within fixed-effects slope(s).

    This is an optimiser-independent benchmark for the crossed-LMM fixed effects. The
    outcome and every regressor are residualised on additive rater and face effects by
    alternating projections: repeatedly subtract rater means, then face means. The
    residualised outcome is then regressed on the residualised regressors by least
    squares, with no intercept.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``rater_id``, ``face_id``, ``ycol`` and every column of ``xcols``.
    xcols : list of str
        Regressor columns. A column that is constant within rater (or face) is absorbed
        and becomes ~0 after demeaning.
    ycol : str
        Outcome column.
    n_iter : int
        Maximum number of sweeps; one sweep is rater demeaning followed by face demeaning.
    tol : float or None
        Stop early once the largest mean removed in a sweep is <= ``tol``. With ``None``,
        exactly ``n_iter`` sweeps run. If ``tol`` is set and not reached, a note is
        written to stderr.

    Returns
    -------
    dict
        ``{column: slope}`` for each entry of ``xcols``.

    Raises
    ------
    ValueError
        If ``df`` is empty, a used column has NaN/inf, or a rater/face key is missing.
    """
    import sys

    df = df.reset_index(drop=True)
    xcols = list(xcols)
    cols = [ycol] + xcols
    if len(df) == 0:
        raise ValueError("within_fe_slope: no observations")
    M = np.column_stack([df[c].astype(float).to_numpy() for c in cols])
    finite = np.isfinite(M).all(axis=0)
    if not finite.all():
        bad = [c for c, ok in zip(cols, finite) if not ok]
        raise ValueError("within_fe_slope: non-finite values in column(s) %s" % bad)

    # Factorise each grouping key once; the codes and group sizes are loop-invariant.
    groupings = []
    for key in ("rater_id", "face_id"):
        codes, uniques = pd.factorize(df[key])
        if (codes < 0).any():
            raise ValueError("within_fe_slope: missing values in %r" % key)
        counts = np.bincount(codes, minlength=len(uniques)).astype(float)
        groupings.append((codes, counts))

    max_shift = np.inf
    for _ in range(n_iter):
        max_shift = 0.0
        for codes, counts in groupings:
            for j in range(M.shape[1]):
                group_means = np.bincount(codes, weights=M[:, j], minlength=len(counts)) / counts
                shift = group_means[codes]
                M[:, j] -= shift
                max_shift = max(max_shift, float(np.abs(shift).max()))
        if tol is not None and max_shift <= tol:
            break
    if tol is not None and max_shift > tol:
        print("NOTE: within_fe_slope did not reach tol=%g after %d sweeps (last max shift %.3g)"
              % (tol, n_iter, max_shift), file=sys.stderr)

    y = M[:, 0]
    X = M[:, 1:]
    b = np.linalg.lstsq(X, y, rcond=None)[0]
    return dict(zip(xcols, (float(v) for v in b)))


# --------------------------------------------------------------------------------------
# MAIN
# --------------------------------------------------------------------------------------
def main():
    """Run the full pre-registered pipeline and return the results dictionary.

    Steps:
    1. Load and md5-verify the data, reshape it, and code congruence.
    2. Fit the crossed LMMs for T1 (primary), T2 (sex and age moderation) and T3
       (robustness).
    3. Benchmark every reported congruence estimate against a two-way within
       fixed-effects slope.
    4. Apply Holm over the five-hypothesis family.
    5. Write results.json, the CSV tables and the three figures.
    """
    ratings, info, profile = load_and_verify()
    long = add_congruence(to_long(ratings, info))
    SD = float(long["rating"].std())  # SD over all retained ratings; scale for effect_size_std

    # Analysis subsets
    d1 = long.dropna(subset=["congruent"]).copy()                      # stated-pref defined (T1)
    d_t2 = d1[d1["rater_sex"].isin(["male", "female"])].copy()         # T2 (drop intersex/NaN sex)
    d_t2["age_c"] = d_t2["rater_age"] - d_t2["rater_age"].mean()      # centred on the per-rating mean
    age_mean = float(d_t2["rater_age"].mean())
    d_t3a = d1[d1["rater_sexpref"].isin(["men", "women"])].copy()      # T3a single-gender pref
    d_t3b = long.dropna(subset=["congruent_sex"]).copy()              # T3b sex-based
    n_raters_t1 = int(d1.rater_id.nunique())
    n_raters_t2 = int(d_t2.rater_id.nunique())
    n_raters_t3a = int(d_t3a.rater_id.nunique())
    n_raters_t3b = int(d_t3b.rater_id.nunique())

    # ---------------- T1: primary congruence ----------------
    res_t1 = fit_crossed(d1, "rating ~ congruent")
    t1 = coef_record(res_t1, "congruent")
    # unadjusted per-rater paired comparison (declared fallback / context); raters who only
    # rated one congruence level (e.g. 'either'/'neither') drop out via dropna().
    pr = d1.groupby(["rater_id", "congruent"])["rating"].mean().unstack().dropna()
    diff = pr[1.0] - pr[0.0]
    w_stat, w_p = stats.wilcoxon(pr[1.0], pr[0.0])
    # ---------------- T2: moderation by sex and age ----------------
    # Fitted twice with different reference levels so both sex-specific slopes get a direct CI.
    res_sex_f = fit_crossed(d_t2, "rating ~ congruent * C(rater_sex, Treatment('female'))")
    res_sex_m = fit_crossed(d_t2, "rating ~ congruent * C(rater_sex, Treatment('male'))")
    inter = coef_record(res_sex_f, "congruent:C(rater_sex, Treatment('female'))[T.male]")
    slope_female = coef_record(res_sex_f, "congruent")
    slope_male = coef_record(res_sex_m, "congruent")
    res_age = fit_crossed(d_t2, "rating ~ congruent * age_c")
    age_int = coef_record(res_age, "congruent:age_c")
    age_main = coef_record(res_age, "congruent")   # congruence effect at the mean age (age_c = 0)
    n_male = int(d_t2[d_t2.rater_sex == "male"].rater_id.nunique())
    n_female = int(d_t2[d_t2.rater_sex == "female"].rater_id.nunique())

    # ---------------- T3: robustness ----------------
    res_t3a = fit_crossed(d_t3a, "rating ~ congruent")
    t3a = coef_record(res_t3a, "congruent")
    res_t3b = fit_crossed(d_t3b, "rating ~ congruent_sex")
    t3b = coef_record(res_t3b, "congruent_sex")

    # ---------------- Estimator validation: two-way (rater+face) within fixed-effects
    # benchmark for EVERY reported fixed-effect estimate (optimiser-independent). Each
    # crossed-LMM fixed effect below is compared against the within-FE slope of the same term.
    fe_bench = {}
    fe_bench["T1_congruence"] = within_fe_slope(d1, ["congruent"])["congruent"]
    # T2 sex interaction: congruent + congruent:male; male main effect absorbed by rater demeaning.
    d_sex = d_t2.copy()
    d_sex["male"] = (d_sex["rater_sex"] == "male").astype(float)
    d_sex["cong_male"] = d_sex["congruent"] * d_sex["male"]
    fe_sex = within_fe_slope(d_sex, ["congruent", "cong_male"])
    fe_bench["T2_sex_interaction"] = fe_sex["cong_male"]      # matches inter (male vs female add)
    fe_bench["T2_sex_slope_female"] = fe_sex["congruent"]     # matches slope_female
    fe_bench["T2_sex_slope_male"] = fe_sex["congruent"] + fe_sex["cong_male"]  # matches slope_male
    # T2 age interaction: congruent + congruent:age_c. Like 'male' above, age_c is constant
    # within rater, so rater demeaning absorbs it; including it would only add a numerically
    # zero column to the least-squares design.
    d_age = d_t2.copy()
    d_age["cong_age"] = d_age["congruent"] * d_age["age_c"]
    fe_age = within_fe_slope(d_age, ["congruent", "cong_age"])
    fe_bench["T2_age_interaction"] = fe_age["cong_age"]
    fe_bench["T2_age_slope_at_mean"] = fe_age["congruent"]    # matches age_main (age_c = 0)
    # T3a single-gender; T3b sex-based.
    fe_bench["T3a_single_gender"] = within_fe_slope(d_t3a, ["congruent"])["congruent"]
    fe_bench["T3b_sex_based"] = within_fe_slope(d_t3b, ["congruent_sex"])["congruent_sex"]

    # Mixed-model counterparts, paired to the benchmarks above.
    mixed_for_bench = {
        "T1_congruence": t1["coef"],
        "T2_sex_interaction": inter["coef"],
        "T2_sex_slope_female": slope_female["coef"],
        "T2_sex_slope_male": slope_male["coef"],
        "T2_age_interaction": age_int["coef"],
        "T2_age_slope_at_mean": age_main["coef"],
        "T3a_single_gender": t3a["coef"],
        "T3b_sex_based": t3b["coef"],
    }
    fe_agreement = {k: {"within_fe": float(fe_bench[k]), "mixed": float(mixed_for_bench[k]),
                        "abs_diff": float(abs(fe_bench[k] - mixed_for_bench[k]))}
                    for k in mixed_for_bench}
    max_abs_diff = max(v["abs_diff"] for v in fe_agreement.values())

    # ---------------- Multiplicity correction (Holm over the pre-registered family) --------
    family = {
        "T1_congruence": t1["p"],
        "T2_sex_interaction": inter["p"],
        "T2_age_interaction": age_int["p"],
        "T3a_single_gender": t3a["p"],
        "T3b_sex_based": t3b["p"],
    }
    keys = list(family)
    reject, p_holm, _, _ = multipletests([family[k] for k in keys], alpha=0.05, method="holm")
    holm = {k: {"p_raw": float(family[k]), "p_holm": float(ph), "reject_h0": bool(rj)}
            for k, ph, rj in zip(keys, p_holm, reject)}

    # ---------------- results.json ----------------
    results = {
        "meta": {
            "question": ("Do raters give higher attractiveness ratings to faces of their preferred "
                         "gender (preference congruence), and does that hold once the crossed "
                         "repeated-measures structure is accounted for?"),
            "dataset": ("Face Research Lab London Set (DeBruine & Jones), "
                        "figshare 10.6084/m9.figshare.5047666, CC-BY-4.0"),
            "md5": profile["md5"],
            "n_raters_total": profile["n_raters_total"], "n_faces": profile["n_faces"],
            "n_observations_total": int(len(long)),
            "rating_scale": "1-7 attractiveness", "rating_sd": SD,
            "congruence_mapping": {"men": "prefers male", "women": "prefers female",
                                   "either": "prefers both (all congruent)",
                                   "neither": "prefers neither (all incongruent)",
                                   "NaN": "undefined (excluded)"},
            "seed": SEED, "multiplicity_method": "Holm-Bonferroni over 5 pre-registered tests",
            "primary_estimator": ("crossed linear mixed model rating ~ fixed + (1|rater) + (1|face_id), "
                                  "REML, statsmodels MixedLM"),
            "estimator_validation": {
                "method": ("two-way (rater+face) within fixed-effects slope, computed for every "
                           "reported fixed-effect estimate as an optimiser-independent benchmark"),
                "terms": fe_agreement,
                "max_abs_diff_points": max_abs_diff,
            },
        },
        "T1_congruence": {
            "model": "rating ~ congruent + (1|rater) + (1|face_id)",
            "n_obs": int(len(d1)), "n_raters": n_raters_t1,
            "estimate_points": t1["coef"], "se": t1["se"], "ci95": [t1["ci_low"], t1["ci_high"]],
            "p_raw": t1["p"], "p_holm": holm["T1_congruence"]["p_holm"],
            "effect_size_std": t1["coef"] / SD,
            # Assumes statsmodels orders vcomp by the sorted vc_formula keys ('face' < 'rater').
            "var_rater": float(res_t1.vcomp[1]), "var_face": float(res_t1.vcomp[0]),
            "resid_var": float(res_t1.scale),
            "unadjusted_per_rater_mean_diff": float(diff.mean()),
            "unadjusted_wilcoxon_W": float(w_stat), "unadjusted_wilcoxon_p": float(w_p),
            "unadjusted_paired_n": int(len(diff)),
            "note": ("Adjusted effect is positive but very small; the naive per-rater difference is "
                     "negative (confounded by face composition) — a Simpson-type reversal."),
        },
        "T2_sex_interaction": {
            "model": "rating ~ congruent * rater_sex + (1|rater) + (1|face_id)",
            "n_obs": int(len(d_t2)), "n_raters": n_raters_t2,
            "interaction_estimate_points": inter["coef"],
            "interaction_ci95": [inter["ci_low"], inter["ci_high"]],
            "interaction_p_raw": inter["p"], "interaction_p_holm": holm["T2_sex_interaction"]["p_holm"],
            "slope_female_ref": {"estimate": slope_female["coef"],
                                 "ci95": [slope_female["ci_low"], slope_female["ci_high"]],
                                 "p": slope_female["p"], "n_raters": n_female},
            "slope_male": {"estimate": slope_male["coef"],
                           "ci95": [slope_male["ci_low"], slope_male["ci_high"]],
                           "p": slope_male["p"], "n_raters": n_male},
            # %+.3f renders the sign itself (a literal '+' prefix printed '+-x' for negatives).
            "note": ("Congruence effect is driven by male raters (%+.3f); for female raters it is "
                     "null (%+.3f, 95%% CI includes 0)." % (slope_male["coef"], slope_female["coef"])),
        },
        "T2_age_interaction": {
            "model": "rating ~ congruent * age_centered + (1|rater) + (1|face_id)",
            "n_obs": int(len(d_t2)), "n_raters": n_raters_t2,
            "age_mean_years": age_mean,
            "congruent_at_mean_age": age_main["coef"],
            "interaction_estimate_per_year": age_int["coef"],
            "interaction_ci95": [age_int["ci_low"], age_int["ci_high"]],
            "interaction_p_raw": age_int["p"], "interaction_p_holm": holm["T2_age_interaction"]["p_holm"],
            "note": "Congruence effect increases with rater age by ~%.4f points/year." % age_int["coef"],
        },
        "T3a_single_gender": {
            "model": "rating ~ congruent + (1|rater) + (1|face_id); raters with men/women preference only",
            "n_obs": int(len(d_t3a)), "n_raters": n_raters_t3a,
            "estimate_points": t3a["coef"], "ci95": [t3a["ci_low"], t3a["ci_high"]],
            "p_raw": t3a["p"], "p_holm": holm["T3a_single_gender"]["p_holm"],
            "effect_size_std": t3a["coef"] / SD,
            "note": "Essentially identical to primary — robust to excluding either/neither raters.",
        },
        "T3b_sex_based": {
            "model": ("rating ~ congruent_sex + (1|rater) + (1|face_id); "
                      "congruence from rater_sex (heterosexual assumption)"),
            "n_obs": int(len(d_t3b)), "n_raters": n_raters_t3b,
            "estimate_points": t3b["coef"], "ci95": [t3b["ci_low"], t3b["ci_high"]],
            "p_raw": t3b["p"], "p_holm": holm["T3b_sex_based"]["p_holm"],
            "effect_size_std": t3b["coef"] / SD,
            "note": ("Sign FLIPS to negative when congruence is defined by assumed-heterosexual "
                     "rater sex instead of stated preference — the effect is specific to stated preference."),
        },
    }
    with open(os.path.join(HERE, "results.json"), "w") as f:
        json.dump(results, f, indent=2)

    # ---------------- Tables ----------------
    write_tables(results, holm, profile, info, SD)

    # ---------------- Figures ----------------
    fig1(diff, t1)
    fig2(slope_male, slope_female, inter, age_main, age_int, age_mean, d_t2, n_male, n_female)
    # coef_record() dicts carry no rater counts, but fig3 labels each row with one.
    fig3(dict(t1, n_raters=n_raters_t1), dict(t3a, n_raters=n_raters_t3a),
         dict(t3b, n_raters=n_raters_t3b))

    print("Done. T1 congruent = %.4f [%.4f, %.4f], p_holm=%.2e"
          % (t1["coef"], t1["ci_low"], t1["ci_high"], holm["T1_congruence"]["p_holm"]))
    return results


# --------------------------------------------------------------------------------------
# Tables
# --------------------------------------------------------------------------------------
def write_tables(res, holm, profile, info, SD):
    """Write the four CSV result tables into TBL_DIR.

    tbl-1: T1 adjusted (crossed LMM) effect plus the unadjusted per-rater comparison.
    tbl-2: T2 sex interaction, sex-specific slopes and age interaction.
    tbl-3: primary specification next to the T3a/T3b robustness specifications.
    tbl-4: sample composition (rater sex/preference counts, mean rater age, face genders).

    Only ``res`` (the results dict) and ``profile`` are read; ``holm``, ``info`` and ``SD``
    are accepted for interface compatibility but unused.
    """
    def save(rows, filename):
        pd.DataFrame(rows).to_csv(os.path.join(TBL_DIR, filename), index=False)

    def ci_row(label_col, label, estimate, ci, p_raw, p_holm, n_raters):
        # Common column layout of tables 2 and 3.
        return {label_col: label, "estimate": estimate, "ci_low": ci[0], "ci_high": ci[1],
                "p_raw": p_raw, "p_holm": p_holm, "n_raters": n_raters}

    primary = res["T1_congruence"]
    adjusted_row = {
        "quantity": "Adjusted congruence effect (crossed LMM)",
        "estimate": primary["estimate_points"],
        "ci_low": primary["ci95"][0],
        "ci_high": primary["ci95"][1],
        "p_raw": primary["p_raw"],
        "p_holm": primary["p_holm"],
        "effect_size_std": primary["effect_size_std"],
        "n_obs": primary["n_obs"],
        "n_raters": primary["n_raters"],
    }
    unadjusted_row = {
        "quantity": "Unadjusted per-rater mean difference (congruent - incongruent)",
        "estimate": primary["unadjusted_per_rater_mean_diff"],
        "ci_low": None,
        "ci_high": None,
        "p_raw": primary["unadjusted_wilcoxon_p"],
        "p_holm": None,
        "effect_size_std": None,
        "n_obs": primary["n_obs"],
        "n_raters": primary["unadjusted_paired_n"],
    }
    save([adjusted_row, unadjusted_row], "tbl-1-t1-congruence.csv")

    sex = res["T2_sex_interaction"]
    age = res["T2_age_interaction"]
    male, female = sex["slope_male"], sex["slope_female_ref"]
    save([
        ci_row("term", "Congruence x rater sex (male vs female), interaction",
               sex["interaction_estimate_points"], sex["interaction_ci95"],
               sex["interaction_p_raw"], sex["interaction_p_holm"], sex["n_raters"]),
        ci_row("term", "Congruence slope, male raters", male["estimate"], male["ci95"],
               male["p"], None, male["n_raters"]),
        ci_row("term", "Congruence slope, female raters", female["estimate"], female["ci95"],
               female["p"], None, female["n_raters"]),
        ci_row("term", "Congruence x rater age, interaction (per year)",
               age["interaction_estimate_per_year"], age["interaction_ci95"],
               age["interaction_p_raw"], age["interaction_p_holm"], age["n_raters"]),
    ], "tbl-2-t2-moderation.csv")

    specs = [
        ("Primary (stated preference, all raters)", primary),
        ("T3a: single-gender preference (men/women only)", res["T3a_single_gender"]),
        ("T3b: sex-based expected preference (heterosexual assumption)", res["T3b_sex_based"]),
    ]
    save([ci_row("specification", label, r["estimate_points"], r["ci95"], r["p_raw"],
                 r["p_holm"], r["n_raters"]) for label, r in specs],
         "tbl-3-t3-robustness.csv")

    composition = [
        {"variable": var, "category": str(val), "n": int(cnt)}
        for var in ("rater_sex", "rater_sexpref")
        for val, cnt in profile[var].items()
    ]
    composition.append({"variable": "rater_age", "category": "mean",
                        "n": round(profile["rater_age_mean"], 1)})
    composition.extend({"variable": "face_gender", "category": str(val), "n": int(cnt)}
                       for val, cnt in profile["face_gender"].items())
    save(composition, "tbl-4-composition.csv")


# --------------------------------------------------------------------------------------
# Figures
# --------------------------------------------------------------------------------------
def fig1(diff, t1):
    """Draw Figure 1: unadjusted per-rater congruence difference vs the adjusted T1 estimate.

    Panel A is a histogram of the per-rater (congruent - incongruent) mean differences with
    their mean marked. Panel B compares that mean, with a t-based 95% CI, against the
    crossed-LMM estimate and its CI. The figure is saved to
    FIG_DIR/fig-1-preference-congruence.png.

    Parameters
    ----------
    diff : pandas.Series
        Per-rater mean rating difference, congruent minus incongruent.
    t1 : dict
        Output of coef_record for the T1 ``congruent`` term (uses coef, ci_low, ci_high).
    """
    mean_diff = float(diff.mean())
    n = len(diff)
    fig, (axA, axB) = plt.subplots(1, 2, figsize=(7.2, 3.4))

    axA.axvline(0, color=GREY, lw=1, ls="--")
    axA.hist(diff.values, bins=40, color=GREY, alpha=0.55, edgecolor="white", lw=0.3)
    axA.axvline(mean_diff, color=ORANGE, lw=2)
    axA.annotate("mean = %.2f" % mean_diff, xy=(mean_diff, 0),
                 xytext=(mean_diff - 0.05, axA.get_ylim()[1] * 0.9), ha="right",
                 color=ORANGE, fontsize=7)
    axA.set_xlabel("Per-rater mean rating:\ncongruent - incongruent (points)")
    axA.set_ylabel("Raters (n = %d)" % n)
    axA.set_title("Unadjusted per-rater difference", loc="left")

    # One-sample t interval for the unadjusted mean difference.
    se_un = diff.std(ddof=1) / np.sqrt(n)
    tcrit = stats.t.ppf(0.975, n - 1)
    ests = [mean_diff, t1["coef"]]
    lo = [mean_diff - tcrit * se_un, t1["ci_low"]]
    hi = [mean_diff + tcrit * se_un, t1["ci_high"]]
    for y, e, l, h, c in zip([1, 0], ests, lo, hi, [GREY, BLUE]):
        axB.plot([l, h], [y, y], color=c, lw=2.2, solid_capstyle="round")
        axB.plot(e, y, "o", color=c, ms=8, zorder=5)
        axB.annotate("%+.3f\n[%.3f, %.3f]" % (e, l, h), xy=(e, y),
                     xytext=(0, 12 if y == 0 else -22), textcoords="offset points",
                     ha="center", fontsize=6.8, color=c)
    axB.axvline(0, color=GREY, lw=1, ls="--")
    axB.set_yticks([1, 0]); axB.set_yticklabels(["Unadjusted\n(per-rater)", "Crossed LMM\n(rater + face adj.)"])
    axB.set_ylim(-0.6, 1.6); axB.margins(x=0.18)
    axB.set_xlabel("Congruence effect on attractiveness rating (points)")
    # Claim a reversal only when the two estimates actually have opposite signs.
    reverses = mean_diff * t1["coef"] < 0
    axB.set_title("Effect reverses after adjustment" if reverses else "Unadjusted vs adjusted effect",
                  loc="left")
    fig.suptitle("T1 - Preferred-gender congruence and attractiveness ratings",
                 fontsize=9, y=1.02, x=0.01, ha="left")
    fig.tight_layout()
    fig.savefig(os.path.join(FIG_DIR, "fig-1-preference-congruence.png"), bbox_inches="tight")
    plt.close(fig)


def fig2(slope_male, slope_female, inter, age_main, age_int, age_mean, d_t2, n_male, n_female):
    """Draw Figure 2: moderation of the congruence effect by rater sex (A) and rater age (B).

    The figure is saved to FIG_DIR/fig-2-moderation-sex-age.png.

    Parameters
    ----------
    slope_male, slope_female, inter, age_main, age_int : dict
        Outputs of coef_record (coef, ci_low, ci_high, p) for, respectively:
        the male and female congruence slopes, the sex interaction, the congruence
        effect at the mean age, and the age interaction.
    age_mean : float
        Centring constant of ``age_c``. The plotted line is
        age_main + age_int * (age - age_mean).
    d_t2 : pandas.DataFrame
        T2 analysis data. The 2nd-98th percentile of ``rater_age`` sets the plotted age range.
    n_male, n_female : int
        Rater counts shown in the panel-A labels.
    """
    fig, (axA, axB) = plt.subplots(1, 2, figsize=(7.4, 3.5))

    # Panel A: one CI bar per rater sex (male at y=1, female at y=0).
    sex_labels = ["Male raters\n(n=%d)" % n_male, "Female raters\n(n=%d)" % n_female]
    for y, est, col in zip([1, 0], [slope_male, slope_female], [BLUE, ORANGE]):
        axA.plot([est["ci_low"], est["ci_high"]], [y, y], color=col, lw=2.4, solid_capstyle="round")
        axA.plot(est["coef"], y, "o", color=col, ms=8, zorder=5)
        axA.annotate("%+.3f [%.3f, %.3f]" % (est["coef"], est["ci_low"], est["ci_high"]),
                     xy=(est["coef"], y), xytext=(0, 11), textcoords="offset points",
                     ha="center", fontsize=6.8, color=col)
    axA.axvline(0, color=GREY, lw=1, ls="--")
    axA.set_yticks([1, 0]); axA.set_yticklabels(sex_labels)
    axA.set_ylim(-0.6, 1.6); axA.margins(x=0.20)
    axA.set_xlabel("Congruence effect (rating points)")
    axA.set_title("Congruence effect by rater sex\n(interaction p = %.1e)" % inter["p"],
                  loc="left", fontsize=8)

    # Panel B: model-implied congruence effect across the central 96% of rater ages.
    ages = np.linspace(d_t2.rater_age.quantile(0.02), d_t2.rater_age.quantile(0.98), 100)
    slope_at = age_main["coef"] + age_int["coef"] * (ages - age_mean)
    axB.axhline(0, color=GREY, lw=1, ls="--")
    axB.plot(ages, slope_at, color=BLUE, lw=2.2)
    # Anchor the label in axes coordinates so it stays inside the panel whatever the line's sign.
    axB.annotate("slope = %+.4f / yr\n[%.4f, %.4f]\np = %.1e"
                 % (age_int["coef"], age_int["ci_low"], age_int["ci_high"], age_int["p"]),
                 xy=(0.04, 0.96), xycoords="axes fraction", va="top",
                 fontsize=6.8, color=BLUE)
    axB.set_xlabel("Rater age (years)"); axB.set_ylabel("Congruence effect (rating points)")
    # Describe the direction the estimate actually shows rather than asserting one.
    if age_int["coef"] > 0:
        age_title = "Congruence effect increases with age"
    elif age_int["coef"] < 0:
        age_title = "Congruence effect decreases with age"
    else:
        age_title = "Congruence effect by rater age"
    axB.set_title(age_title, loc="left", fontsize=8); axB.margins(x=0.03)
    fig.suptitle("T2 - Moderation of the congruence effect by rater sex and age",
                 fontsize=9, y=1.03, x=0.01, ha="left")
    fig.tight_layout(); fig.subplots_adjust(wspace=0.35)
    fig.savefig(os.path.join(FIG_DIR, "fig-2-moderation-sex-age.png"), bbox_inches="tight")
    plt.close(fig)


def fig3(t1, t3a, t3b):
    """Draw Figure 3: the primary congruence estimate next to the T3a/T3b robustness estimates.

    The figure is saved to FIG_DIR/fig-3-robustness.png.

    Parameters
    ----------
    t1, t3a, t3b : dict
        Each needs ``coef``, ``ci_low`` and ``ci_high`` (coef_record output). An optional
        ``n_raters`` entry is shown in the row label. coef_record itself does not provide
        it, so it is looked up with ``.get`` rather than required.
    """
    specs = [
        ("Primary: stated preference\n(all raters)", t1, BLUE),
        ("T3a: single-gender preference\n(men/women only)", t3a, BLUE),
        ("T3b: sex-based expected preference\n(heterosexual assumption)", t3b, ORANGE),
    ]
    fig, ax = plt.subplots(figsize=(6.8, 3.3))
    yp = list(range(len(specs) - 1, -1, -1))
    for y, (_, est, col) in zip(yp, specs):
        c, l, h = est["coef"], est["ci_low"], est["ci_high"]
        ax.plot([l, h], [y, y], color=col, lw=2.6, solid_capstyle="round")
        ax.plot(c, y, "o", color=col, ms=9, zorder=5)
        label = "%+.3f [%.3f, %.3f]" % (c, l, h)
        n_raters = est.get("n_raters")
        if n_raters is not None:
            label += "  (n=%d)" % n_raters
        ax.annotate(label, xy=(c, y), xytext=(0, 12), textcoords="offset points",
                    ha="center", fontsize=6.8, color=col)
    ax.axvline(0, color=GREY, lw=1.2, ls="--")
    ax.set_yticks(yp); ax.set_yticklabels([s[0] for s in specs], fontsize=7.5)
    # Fit the x-range to the data (always including 0) instead of a fixed window that could clip CIs.
    x_lo = min(0.0, *(s[1]["ci_low"] for s in specs))
    x_hi = max(0.0, *(s[1]["ci_high"] for s in specs))
    pad = 0.25 * (x_hi - x_lo) if x_hi > x_lo else 0.01
    ax.set_ylim(-0.8, len(specs) - 0.2); ax.set_xlim(x_lo - pad, x_hi + pad)
    ax.set_xlabel("Congruence effect on attractiveness rating (points)")
    ax.set_title("T3 - Robustness of the congruence effect across definitions", loc="left", fontsize=9)
    # Flag a sign flip only when the sex-based and primary estimates disagree in sign.
    if t3b["coef"] * t1["coef"] < 0:
        ax.annotate("sign flips under sex-based definition", xy=(t3b["coef"], yp[2] - 0.28),
                    ha="center", fontsize=6.5, color=ORANGE, style="italic")
    fig.tight_layout()
    fig.savefig(os.path.join(FIG_DIR, "fig-3-robustness.png"), bbox_inches="tight")
    plt.close(fig)


if __name__ == "__main__":
    # Script entry point: run the whole pipeline; main()'s returned results dict is not used here.
    main()
