"""
AGN Contribution to Composite SEDs
===================================

Active galactic nuclei (AGN) can dominate galaxy SEDs from the UV to the IR.
This script sweeps the AGN fraction f_AGN ∈ {0.0, 0.1, 0.3, 0.5, 0.8, 1.0},
blending a star-forming galaxy's stellar continuum with an AGN accretion-disc
spectrum (rescaled so that its peak matches the stellar peak), and shows the
transition from star-formation-dominated to AGN-dominated SED morphology.

.. sphx-glr-precomputed-img:

.. image:: images/sphx_glr_plot_panchromatic_agn_fraction_001.png
   :alt: plot_panchromatic_agn_fraction
   :class: sphx-glr-single-img

"""

# sphinx_gallery_thumbnail_number = 1

from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

jax.config.update("jax_enable_x64", True)

from tengri import (
    Fixed,
    Observation,
    Parameters,
    SEDModel,
    Spectroscopy,
    load_ssp_data,
    setup_style,
)
from tengri.components.agn import qsogen

setup_style()


def _find_ssp():
    """Return the path of the example's SSP HDF5 file, or ``None`` if absent.

    The file is looked for in a ``data`` directory located in the current
    working directory or up to three levels above it, nearest first; the
    first candidate that exists is returned.

    Returns
    -------
    str or None
        The matching path (relative to the working directory) as a string,
        or ``None`` when none of the candidates exists.
    """
    filename = "ssp_prsc_miles_chabrier_wNE_logGasU-3.0_logGasZ0.0.h5"
    # Nearest directory first, so a local copy shadows one further up.
    roots = ("data", "../data", "../../data", "../../../data")
    candidates = (Path(root) / filename for root in roots)
    hit = next((candidate for candidate in candidates if candidate.exists()), None)
    return None if hit is None else str(hit)


# Resolve the SSP library up front: every model below is built from it, so a
# missing file aborts the script with FileNotFoundError.
if (ssp_path := _find_ssp()) is None:
    raise FileNotFoundError("SSP data not found — skipping example")

ssp = load_ssp_data(ssp_path)

# Wavelength grid handed to the model as ``wave_obs``: 600 log-spaced points
# from 1e3 Å (0.1 µm) to 1e6 Å (100 µm), i.e. from the UV out to the far-IR.
wave_sed = jnp.logspace(jnp.log10(1000.0), jnp.log10(1e6), 600)  # [Å]
obs = Observation(spectroscopy=Spectroscopy(wave_obs=wave_sed))

# Base stellar SED (star-forming): skew-normal SFH with Draine & Li (2007)
# dust emission, every physical parameter pinned with Fixed(...).
spec = Parameters(
    mean_sfh_type="tsnorm",
    dust_emission="draine_li2007",
    sfh_tsnorm_log_peak_sfr=Fixed(1.0),  # peak SFR ~ 10 Msun/yr
    sfh_tsnorm_peak_lbt_gyr=Fixed(2.0),
    sfh_tsnorm_width_gyr=Fixed(1.5),
    sfh_tsnorm_skew=Fixed(0.0),
    sfh_tsnorm_trunc=Fixed(2.0),
    met_logzsol=Fixed(0.0),
    dust_tau_bc=Fixed(0.3),
    dust_tau_diff=Fixed(0.2),
    dust_slope=Fixed(-0.7),
    dust_umin=Fixed(2.0),
    dust_qpah=Fixed(3.5),
    dust_gamma_dl=Fixed(0.02),
    redshift=Fixed(0.05),
)

model = SEDModel(spec, ssp, observation=obs)
# NOTE(review): all parameters set above are Fixed, so the sampled values are
# presumably just those constants; confirm no defaulted parameter is free.
key = jax.random.PRNGKey(42)
params = spec.sample(key)
pred_stellar = model.predict_rest_sed(params)

# Rest-frame wavelengths: Å for the AGN model, µm for plotting.
wave_rest_aa = np.array(pred_stellar.wavelength)
wave_rest_um = wave_rest_aa / 1e4
sed_stellar = np.array(pred_stellar.sed)

# AGN component: QSOgen spectrum at a moderate bolometric luminosity, evaluated
# on the stellar model's rest-frame grid (Å) so both SEDs share one axis and
# can be summed element-wise.
log_lbol_agn = 11.0  # intended L_bol ~ 10^11 L_sun — NOTE(review): confirm qsogen's units
wave_sed_agn = jnp.array(wave_rest_aa)
agn_model_output = qsogen(wave_sed_agn, agn_log_lbol=log_lbol_agn)
sed_agn = np.array(agn_model_output)

# AGN fractions to sweep; each gets one curve, coloured along viridis.
agn_fracs = [0.0, 0.1, 0.3, 0.5, 0.8, 1.0]
cmap = plt.cm.viridis
# max(..., 1) keeps the colour index finite when only one fraction is swept.
n_steps = max(len(agn_fracs) - 1, 1)
colors = [cmap(i / n_steps) for i, _ in enumerate(agn_fracs)]

# Rescale the AGN spectrum so its peak equals the stellar peak, making f_AGN a
# mixing weight between two spectra of comparable amplitude. The scale factor
# does not depend on f_AGN, so it is computed once rather than per curve.
agn_peak = np.nanmax(sed_agn[sed_agn > 0])
stellar_peak = np.nanmax(sed_stellar[sed_stellar > 0])
sed_agn_normalized = sed_agn * (stellar_peak / agn_peak)

fig, ax = plt.subplots(figsize=(10, 6))

for agn_frac, color in zip(agn_fracs, colors):
    # Linear blend: L_nu = (1 - f_AGN) * L_stellar + f_AGN * L_AGN(normalized)
    sed_composite = (1.0 - agn_frac) * sed_stellar + agn_frac * sed_agn_normalized

    # Non-positive (or NaN) samples cannot be drawn on log-log axes; drop them.
    mask = sed_composite > 0
    ax.loglog(
        wave_rest_um[mask],
        sed_composite[mask],
        color=color,
        lw=2.0,
        label=rf"$f_{{\mathrm{{AGN}}}} = {agn_frac}$",
    )

# The curves are plotted against rest-frame wavelengths (from
# predict_rest_sed), so say so on the axis; the source sits at z = 0.05.
ax.set_xlabel(r"Rest-frame wavelength [$\mu$m]", fontsize=12)
ax.set_ylabel(r"$L_\nu$ [erg s$^{-1}$ Hz$^{-1}$]", fontsize=12)
ax.set_title("Galaxy + AGN Composite: AGN Fraction Sweep", fontsize=14)
ax.legend(fontsize=11, frameon=False, loc="upper right")
ax.grid(True, alpha=0.3, which="both")
ax.set_xlim(0.08, 1e2)
ax.set_ylim(1e22, 1e34)

fig.tight_layout()
# Save next to this script; fall back to the working directory when __file__
# is undefined (e.g. when the code is run interactively).
script_dir = Path(__file__).resolve().parent if "__file__" in globals() else Path(".")
# Use the explicit figure handle rather than pyplot's implicit "current" figure.
fig.savefig(script_dir / "plot_panchromatic_agn_fraction.png", dpi=150, bbox_inches="tight")
plt.close(fig)
