from PySide6.QtGui import QCursor, QFont
from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import QApplication, QFrame, QWidget, QSlider, QHBoxLayout, QVBoxLayout, QLabel, QDoubleSpinBox, QSizePolicy
import sys
from ..basemodule import *
import pyqtgraph as pg
import numpy as np
from astropy.table import Table
from astropy.io import fits
from astropy.stats import sigma_clip
import pandas as pd
from importlib.resources import files
from importlib.metadata import PackageNotFoundError, version as dist_version
from pathlib import Path
import os
# Resolve the bundled QSO template shipped inside the ``specbox`` package and
# load it once at import time; the viewer overlays it on Euclid spectra.
data_path = Path(files("specbox").joinpath("data/templates"))
fits_file = data_path.joinpath("qso1", "optical_nir_qso_template_v1.fits")
tb_temp = Table.read(os.fspath(fits_file))
# Expose the template columns under the capitalised names used by the viewer.
for _old_name, _new_name in (('wavelength', 'Wave'), ('flux', 'Flux')):
    tb_temp.rename_column(_old_name, _new_name)
del _old_name, _new_name

# Version shown in the window title; fall back to a placeholder when the
# distribution metadata for ``specbox`` is not available.
try:
    viewer_version = dist_version("specbox")
except PackageNotFoundError:
    viewer_version = "0.0.0"
class PGSpecPlot(pg.PlotWidget):
    """Interactive spectrum viewer based on :mod:`pyqtgraph`.

    This widget was originally developed for Euclid SIR 1D spectra but has
    been generalized to handle any spectrum class derived from
    :class:`~specbox.basemodule.ConvenientSpecMixin` (e.g. ``SpecEuclid1d``,
    ``SpecLAMOST``).

    Besides the plot itself the widget creates three companion widgets that
    are *not* placed in any layout here: ``slider`` and ``redshiftSpin``
    (two synchronised redshift controls) and ``spectrum_info_label``.

    ``counter`` is the 0-based index of the *next* spectrum to load, which is
    also the 1-based number of the spectrum currently on screen.

    ``history`` maps ``objid`` to ``[objname, ra, dec, class_vi, z_vi]``.

    :param spectra: a list/tuple of individual spectrum files, or a single
        file holding several spectra.
    :param SpecClass: class used to load a spectrum. It is called as
        ``SpecClass(filename)`` for list input and as
        ``SpecClass(specfile, ext=index + 1)`` for single-file input. If it
        provides ``count_in_file(spectra)``, that is used to count the
        spectra of a single file; otherwise the file is opened with
        :func:`astropy.io.fits.open` and ``len(hdul) - 1`` is used.
    :param initial_counter: 0-based index of the first spectrum to show;
        reset to 0 when it is not smaller than the number of spectra.
    :param z_max: upper limit of the redshift slider and spin box.
    :param history_dict: existing classification history to continue from;
        the dictionary is used (and mutated) in place.
    """

    coordinate_changed = Signal(float, float)  # Signal for coordinate updates

    # Column names of a history record once its ``objid`` key is prepended.
    _HISTORY_COLUMNS = ('objid', 'objname', 'ra', 'dec', 'class_vi', 'z_vi')

    # (key, console message, stored class label, keep the visual redshift?)
    # Kept as a tuple scanned with ``==`` so that it works whether
    # ``event.key()`` yields a plain int or a ``Qt.Key`` enum member.
    _CLASS_KEYS = (
        (Qt.Key_S, "\tClass: STAR.", 'STAR', False),
        (Qt.Key_G, "\tClass: GALAXY.", 'GALAXY', True),
        (Qt.Key_A, "\tClass: QSO(AGN).", 'QSO', True),
        (Qt.Key_U, "\tClass: UNKNOWN.", 'UNKNOWN', False),
        (Qt.Key_L, "\tClass: LIKELY/Unusual QSO.", 'LIKELY', True),
    )

    # Observed-wavelength window to which the redshifted template is cut when
    # the spectrum's ``telescope`` attribute is ``'euclid'``.
    _EUCLID_TEMPLATE_WINDOW = (12047.4, 18734)

    def __init__(self, spectra, SpecClass=SpecEuclid1d, initial_counter=0,
                 z_max=5.0, history_dict=None):
        super().__init__()
        self.SpecClass = SpecClass

        # ``spectra`` can either be a FITS file containing multiple extensions
        # or a list of individual spectrum files.
        if isinstance(spectra, (list, tuple)):
            self.speclist = list(spectra)
            self.specfile = None
            self.len_list = len(self.speclist)
        else:
            self.specfile = spectra
            if hasattr(SpecClass, "count_in_file"):
                self.len_list = int(SpecClass.count_in_file(spectra))
            else:
                with fits.open(spectra) as hdul:
                    self.len_list = len(hdul) - 1
            self.speclist = None

        if initial_counter >= self.len_list:
            print("No more spectra to plot.\n\t Plotting the first spectrum.")
            initial_counter = 0
        self.history = history_dict if history_dict is not None else {}

        # --- plot appearance -------------------------------------------
        self.setWindowTitle("Spectrum")
        self.resize(1200, 800)
        self.setBackground('w')
        self.showGrid(x=True, y=True)
        self.setMouseEnabled(x=True, y=True)
        self.setLogMode(x=False, y=False)
        self.setAspectLocked(False)
        left_axis = self.getAxis('left')
        left_axis.enableAutoSIPrefix(False)
        self.enableAutoRange()
        self.vb = self.getViewBox()
        self.vb.setMouseMode(self.vb.RectMode)

        # --- redshift controls -----------------------------------------
        # The slider is linear in ln(1 + z): one tick is ``base_z_step``.
        self.z_min = 0.0
        self.z_max = z_max
        self.base_z_step = 0.001
        self.slider_min = 0
        self.slider_max = self._z_to_slider(self.z_max)
        self.message = ''
        self.counter = initial_counter

        # Create slider and spin box for redshift control
        self.slider = QSlider(Qt.Horizontal)
        self.slider.setMinimum(self.slider_min)
        self.slider.setMaximum(self.slider_max)
        self.slider.setTickPosition(QSlider.TicksBelow)
        self.slider.setTickInterval(
            max(1, int((self.slider_max - self.slider_min) / 10)))
        self.slider.valueChanged.connect(self.slider_changed)

        # Create spectrum info title box
        self.spectrum_info_label = QLabel()
        self.spectrum_info_label.setFont(QFont("Arial", 14))
        self.spectrum_info_label.setStyleSheet("""
            QLabel {
                background-color: #f0f0f0;
                border: 2px solid #666666;
                padding: 8px;
                color: black;
                border-radius: 5px;
            }
        """)
        self.spectrum_info_label.setAlignment(Qt.AlignCenter)
        self.spectrum_info_label.setText("Loading spectrum info...")

        self.redshiftSpin = QDoubleSpinBox()
        self.redshiftSpin.setMinimum(self.z_min)
        self.redshiftSpin.setMaximum(self.z_max)
        self.redshiftSpin.setDecimals(4)
        self.redshiftSpin.setSingleStep(0.001)
        self.redshiftSpin.valueChanged.connect(self.spin_changed)

        self.slider.setStyleSheet("""
            QSlider {
                background-color: white;
            }
            QSlider::groove:horizontal {
                border: 1px solid #999999;
                height: 8px;
                background: #b0c4de;
                margin: 2px 0;
            }
            QSlider::handle:horizontal {
                background: #6495ed;
                border: 1px solid #5c5c5c;
                width: 18px;
                margin: -2px 0;
                border-radius: 3px;
            }
        """)

        self.plot_next()

    # ------------------------------------------------------------------
    # Utility methods
    def _load_spec(self, index):
        """Load a spectrum by 0-based ``index`` from ``spectra``."""
        if self.speclist is not None:
            spec = self.SpecClass(self.speclist[index])
        else:
            # Extension 0 is skipped: spectrum ``index`` lives in ext index+1.
            spec = self.SpecClass(self.specfile, ext=index + 1)
        self._ensure_spec_defaults(spec, index)
        return spec

    def _ensure_spec_defaults(self, spec, index=None):
        """Ensure common attributes exist on ``spec``.

        Missing ``z_vi``/``z_ph``/``z_gaia`` default to ``spec.redshift`` (or
        0.0), a missing ``objname`` to ``'Unknown'`` and a missing ``objid``
        to ``index``.  Using the spectrum's own index keeps the fallback id
        identical no matter how the spectrum was reached; ``self.counter`` is
        only used when no ``index`` is given.
        """
        for name in ('z_vi', 'z_ph', 'z_gaia'):
            if not hasattr(spec, name):
                setattr(spec, name, getattr(spec, 'redshift', 0.0))
        if not hasattr(spec, 'objid'):
            spec.objid = self.counter if index is None else index
        if not hasattr(spec, 'objname'):
            spec.objname = 'Unknown'

    def _z_to_slider(self, z):
        """Convert redshift ``z`` to the (truncated) integer slider position."""
        return int((1 / self.base_z_step) * np.log((1 + z) / (1 + self.z_min)))

    def _slider_to_z(self, slider_value):
        """Convert an integer slider position back to a redshift."""
        return np.exp(self.base_z_step * slider_value) * (1 + self.z_min) - 1

    def _set_slider_silently(self, z):
        """Move the slider to ``z`` without triggering ``slider_changed``."""
        self.slider.blockSignals(True)
        self.slider.setValue(self._z_to_slider(z))
        self.slider.blockSignals(False)

    def _set_spin_silently(self, z):
        """Set the spin box to ``z`` without triggering ``spin_changed``."""
        self.redshiftSpin.blockSignals(True)
        self.redshiftSpin.setValue(z)
        self.redshiftSpin.blockSignals(False)

    def _apply_redshift(self, z):
        """Store ``z`` as the visual redshift and redraw the spectrum."""
        if not hasattr(self, 'spec'):
            # Nothing is loaded (e.g. an empty input list); nothing to redraw.
            return
        self.spec.z_vi = z
        self.clear()
        self.plot_single()

    def _history_record(self, spec, class_vi, z_vi):
        """Build the history entry for ``spec``.

        ``ra``/``dec`` are optional on spectrum objects, so missing values
        are stored as NaN instead of raising.
        """
        return [spec.objname, getattr(spec, 'ra', np.nan),
                getattr(spec, 'dec', np.nan), class_vi, z_vi]

    def _history_dataframe(self):
        """Return the history as a DataFrame with one row per object."""
        rows = [[objid, *record[:5]] for objid, record in self.history.items()]
        return pd.DataFrame(rows, columns=list(self._HISTORY_COLUMNS))

    def update_slider_and_spin(self):
        """Point both redshift controls at the current spectrum's ``z_vi``.

        Non-positive redshifts are shown as ``z_min``.  Signals are blocked so
        that the update does not trigger a redraw.
        """
        z_vi = self.spec.z_vi
        initial_z = z_vi if z_vi > 0 else self.z_min
        self._set_slider_silently(initial_z)
        self._set_spin_silently(initial_z)

    # ------------------------------------------------------------------
    # Slots
    def slider_changed(self, slider_value):
        """Slot: the slider moved; update the spin box and redraw."""
        z = self._slider_to_z(slider_value)
        self._set_spin_silently(z)
        self._apply_redshift(z)

    def spin_changed(self, z):
        """Slot: the spin box changed; update the slider and redraw."""
        self._set_slider_silently(z)
        self._apply_redshift(z)

    # ------------------------------------------------------------------
    # Plotting helpers
    def plot_single(self):
        """Draw the current spectrum (and the template for Euclid spectra).

        Non-finite flux values and 10-sigma outliers (3 iterations) are
        dropped before plotting.  The plotted arrays are kept in ``self.wave``
        and ``self.flux`` for the Space-key annotation.  The view is
        auto-ranged after every call.
        """
        spec = self.spec
        # No visual redshift yet: start from ``z_ph`` when it is positive.
        if spec.z_vi == 0 and spec.z_ph > 0:
            spec.z_vi = spec.z_ph
        z_vi = spec.z_vi

        clipped = sigma_clip(np.ma.masked_invalid(spec.flux.value),
                             sigma=10, maxiters=3)
        # ``getmaskarray`` always yields a full boolean array, even when
        # nothing is masked.
        keep = ~np.ma.getmaskarray(clipped)
        wave = spec.wave.value[keep]
        flux = clipped.data[keep]

        self.plot(wave, flux, pen='b',
                  symbol='o', symbolSize=4, symbolPen=None, connect='finite',
                  symbolBrush='k', antialias=True)
        self.wave = wave
        self.flux = flux

        if (getattr(spec, 'telescope', '') or '').lower() == 'euclid':
            lower, upper = self._EUCLID_TEMPLATE_WINDOW
            wave_temp = tb_temp['Wave'].data * (1 + z_vi)
            in_window = (wave_temp >= lower) & (wave_temp <= upper)
            # Skip the overlay when either side is empty: the normalisation
            # below would otherwise average an empty array.
            if in_window.any() and flux.size:
                flux_temp = tb_temp['Flux'].data[in_window]
                # Scale the template to 1.5x the mean |flux| of the spectrum.
                flux_temp = (flux_temp / np.mean(flux_temp)
                             * np.abs(flux.mean()) * 1.5)
                self.plot(wave_temp[in_window], flux_temp,
                          pen=(240, 128, 128), symbol='+',
                          symbolSize=2, symbolPen=None)

        self.setLabel('left', "Flux", units=spec.flux.unit.to_string())
        self.setLabel('bottom', "Wavelength", units=spec.wave.unit.to_string())
        # Update info label above plot
        self.update_spectrum_info_label()
        self.autoRange()

    def update_spectrum_info_label(self):
        """Update the spectrum info label above the plot."""
        if not hasattr(self, 'spec'):
            return
        spec = self.spec
        z_vi = getattr(spec, 'z_vi', 0.0)
        z_gaia = getattr(spec, 'z_gaia', 0.0)
        objid = getattr(spec, 'objid', 'Unknown')

        if hasattr(self, 'len_list'):
            # Prefer the number recorded by the navigation code; fall back to
            # the counter when it has not been set yet.
            number = getattr(self, '_displaying_spectrum_number', self.counter)
            message = f"Spectrum {number}/{self.len_list}"
        else:
            message = ""
        text_content = f"{message} ID: {objid} z_vi = {z_vi:.4f}, z_gaia = {z_gaia:.4f}"
        if hasattr(self, 'spectrum_info_label'):
            self.spectrum_info_label.setText(text_content)

    # ------------------------------------------------------------------
    # Navigation
    def _show_spectrum(self, index):
        """Load spectrum ``index`` (0-based), restore its history and draw it.

        Afterwards ``counter`` is ``index + 1``.  ``coordinate_changed`` is
        emitted when the spectrum has both ``ra`` and ``dec``.
        """
        self.clear()
        spec = self._load_spec(index)
        self.spec = spec
        if spec.objid in self.history:
            record = self.history[spec.objid]
            spec.z_vi = record[4]
            print(f"\tVisual class from history: {record[3]}.")
        self.counter = index + 1
        self._displaying_spectrum_number = index + 1
        print(f"Spectrum {self._displaying_spectrum_number}/{self.len_list}.")
        self.update_slider_and_spin()
        self.plot_single()
        # Emit coordinate change signal for future extensions
        if hasattr(spec, 'ra') and hasattr(spec, 'dec'):
            self.coordinate_changed.emit(spec.ra, spec.dec)

    def _reload_current(self):
        """Re-read the displayed spectrum from disk, discarding edits.

        Unlike :meth:`_show_spectrum` this does not restore ``z_vi`` from the
        history and does not emit ``coordinate_changed``.
        """
        self.clear()
        self.spec = self._load_spec(self.counter - 1)
        # For reload, display number is current counter
        self._displaying_spectrum_number = self.counter
        self.update_slider_and_spin()
        self.plot_single()

    def plot_next(self):
        """Show the next spectrum, or report that the list is exhausted."""
        if self.counter >= self.len_list:
            print("No more spectra to plot.")
            return
        self._show_spectrum(self.counter)

    def plot_previous(self):
        """Show the spectrum before the displayed one, if there is one."""
        if self.counter <= 1:
            print("No previous spectrum to plot.")
            return
        # ``counter - 1`` is the displayed index, so the previous is one less.
        self._show_spectrum(self.counter - 2)

    # ------------------------------------------------------------------
    # Event handlers
    def _cursor_in_view(self):
        """Return the mouse cursor position mapped through the ViewBox."""
        widget_pos = self.mapFromGlobal(QCursor.pos())
        return self.getViewBox().mapSceneToView(widget_pos)

    def _annotate_nearest_point(self):
        """Label the plotted point whose wavelength is nearest the cursor."""
        if len(self.wave) == 0:
            # Every point was rejected; there is nothing to annotate.
            return
        idx = np.abs(self.wave - self._cursor_in_view().x()).argmin()
        wave = self.wave[idx]
        flux = self.flux[idx]
        label = "Wavelength: {0:.2f} Flux: {1:.2e}".format(wave, flux)
        annotation_text = pg.TextItem(
            text=label, anchor=(0, 0), color='r', border='w',
            fill=(255, 255, 255, 200))
        annotation_text.setFont(QFont("Arial", 18, QFont.Bold))
        annotation_text.setPos(wave, flux)
        self.addItem(annotation_text)
        print(label)

    def _save_checkpoint(self):
        """Write the current history to ``vi_temp_<counter - 1>.csv``."""
        print("Saving temp file to csv (n={})...".format(self.counter))
        temp_filename = f"vi_temp_{self.counter-1}.csv"
        self._history_dataframe().to_csv(temp_filename, index=False)

    def _accept_and_advance(self, spec):
        """Handle ``Q``: record ``spec`` and move on to the next spectrum.

        An unclassified object is stored as ``'QSO(Default)'``; an existing
        entry keeps its class and only has its redshift refreshed.  A
        checkpoint CSV is written whenever ``counter - 1`` is a multiple
        of 50.
        """
        if spec.objid not in self.history:
            self.history[spec.objid] = self._history_record(
                spec, 'QSO(Default)', spec.z_vi)
        else:
            # Update existing entry with current z_vi (preserves
            # classification but updates redshift)
            self.history[spec.objid][4] = spec.z_vi
        self.plot_next()
        if (self.counter - 1) % 50 == 0:
            self._save_checkpoint()

    def _handle_ctrl_shortcut(self, key):
        """Run the Ctrl+``key`` shortcut; return True when one was handled."""
        if key == Qt.Key_R:
            self._reload_current()
        elif key == Qt.Key_Right:
            self._show_spectrum(self.len_list - 1)
        elif key == Qt.Key_Left:
            self._show_spectrum(0)
        elif key == Qt.Key_B:
            # Resume at the spectrum numbered ``len(history)``.  Clamp so an
            # empty or oversized history cannot produce an invalid index.
            target = min(max(len(self.history) - 1, 0), self.len_list - 1)
            self._show_spectrum(target)
        else:
            return False
        return True

    def keyPressEvent(self, event):
        """Dispatch the keyboard shortcuts of the viewer.

        ``Q`` accept and go to next; ``Left``/``Right`` previous/next;
        ``S``/``G``/``A``/``U``/``L`` classify; ``Space`` annotate the point
        nearest the cursor; ``M`` print the cursor position; ``R`` redraw;
        ``Ctrl+R`` reload; ``Ctrl+Left``/``Ctrl+Right`` first/last;
        ``Ctrl+B`` resume from history.

        Ctrl shortcuts are checked first and consume the event, so the
        plain-key action of the same key does not run as well.  Ctrl held
        with any other key behaves like the plain key.
        """
        spec = getattr(self, 'spec', None)
        if spec is None:
            # No spectrum could be loaded; there is nothing to act on.
            return
        key = event.key()

        if event.modifiers() & Qt.ControlModifier:
            if self._handle_ctrl_shortcut(key):
                return

        if key == Qt.Key_Q:
            self._accept_and_advance(spec)
        elif key == Qt.Key_M:
            mouse_pos = self._cursor_in_view()
            print(f"Mouse position - Wavelength: {mouse_pos.x():.2f}, Flux: {mouse_pos.y():.2e}")
        elif key == Qt.Key_Space:
            self._annotate_nearest_point()
        elif key == Qt.Key_R:
            self.clear()
            self.plot_single()
        elif key == Qt.Key_Left:
            self.plot_previous()
        elif key == Qt.Key_Right:
            self.plot_next()
        else:
            for class_key, message, class_vi, keep_z in self._CLASS_KEYS:
                if key == class_key:
                    print(message)
                    self.history[spec.objid] = self._history_record(
                        spec, class_vi, spec.z_vi if keep_z else 0.0)
                    break
class PGSpecPlotApp(QApplication):
    """Standalone application running :class:`PGSpecPlot`.

    The application builds the viewer window (instructions, info label,
    plot and redshift controls) and writes the classification history to
    ``output_file`` when it is about to quit.

    :param spectra: passed to :class:`PGSpecPlot` unchanged.
    :param SpecClass: passed to :class:`PGSpecPlot` unchanged.
    :param output_file: CSV file the history is saved to on quit and, with
        ``load_history``, read from on start.
    :param z_max: passed to :class:`PGSpecPlot` unchanged.
    :param load_history: when true and ``output_file`` exists, resume from
        it: its rows seed the history and the viewer starts at the spectrum
        whose index equals the number of rows.
    """

    # Column order of the history CSV.
    _HISTORY_COLUMNS = ('objid', 'objname', 'ra', 'dec', 'class_vi', 'z_vi')

    @staticmethod
    def _normalize_objid(value):
        """Normalize objid loaded from CSV.

        Keeps integers as int (for legacy FITS workflows) and keeps non-numeric
        IDs (e.g. SPARCL UUIDs) as str.  ``None`` and NaN are returned
        unchanged; whole-number floats become int, other floats become str.
        """
        if value is None:
            return value
        if isinstance(value, (np.integer, int)):
            return int(value)
        if isinstance(value, (np.floating, float)):
            # Covers NaN in every float flavour, not just built-in ``float``.
            if np.isnan(value):
                return value
            if float(value).is_integer():
                return int(value)
            return str(value)
        text = str(value)
        try:
            return int(text)
        except ValueError:
            # Not an integer literal: keep the identifier as text.
            return text

    def __init__(self, spectra, SpecClass=SpecEuclid1d,
                 output_file='vi_output.csv', z_max=5.0, load_history=False):
        super().__init__(sys.argv)
        self.output_file = output_file
        self.spectra = spectra
        self.SpecClass = SpecClass

        history_dict = {}
        initial_counter = 0
        if load_history and os.path.exists(self.output_file):
            print(f"Loading history from {self.output_file} ...")
            df = pd.read_csv(self.output_file)
            # Accept the older column name for the classification.
            if 'vi_class' in df.columns:
                df = df.rename(columns={'vi_class': 'class_vi'})
            for _, row in df.iterrows():
                objid = self._normalize_objid(row['objid'])
                history_dict[objid] = [row['objname'], row['ra'],
                                       row['dec'], row['class_vi'],
                                       row['z_vi']]
            initial_counter = df.shape[0]

        self.plot = PGSpecPlot(
            self.spectra, self.SpecClass,
            initial_counter=initial_counter,
            z_max=z_max,
            history_dict=history_dict)
        self.len_list = self.plot.len_list
        self.make_layout()
        self.aboutToQuit.connect(self.save_dict_todf)

    def make_layout(self):
        """Assemble and show the main window.

        Rows from top to bottom: keyboard instructions, spectrum info label,
        the plot, and the redshift slider next to its spin box.
        """
        layout = pg.LayoutWidget()
        layout.resize(1200, 800)
        layout.setWindowTitle(f"PGSpecPlot - Spectra Viewer (v{viewer_version})")

        if self.plot.counter < self.len_list + 1:
            # Instructions with comprehensive keyboard shortcuts
            instruction_text = (
                "Navigation: 'Q' next spectrum, Left/Right arrows previous/next | "
                "Classification: 'A' QSO(AGN), 'S' STAR, 'G' GALAXY, 'U' UNKNOWN, 'L' LIKELY | "
                "Tools: 'Space' wavelength info, 'M' mouse position, 'R' reset zoom | "
                "Advanced: Ctrl+R reload, Ctrl+Left first, Ctrl+Right last, Ctrl+B resume from history"
            )
            toplabel = layout.addLabel(instruction_text, row=0, col=0, colspan=2)
            toplabel.setFont(QFont("Arial", 13))
            toplabel.setMinimumHeight(60)
            toplabel.setMaximumHeight(80)
            toplabel.setAlignment(Qt.AlignLeft | Qt.AlignTop)
            toplabel.setStyleSheet("background-color: white;color: black;")
            # Sunken panel frame (the shadow that was effectively in force).
            toplabel.setFrameStyle(QFrame.Panel | QFrame.Sunken)
            toplabel.setMidLineWidth(2)
            toplabel.setMargin(5)
            toplabel.setIndent(5)
            toplabel.setWordWrap(True)

        # Add spectrum info label above plot
        layout.addWidget(self.plot.spectrum_info_label, row=1, col=0, colspan=2)
        layout.addWidget(self.plot, row=2, col=0, colspan=2)

        slider_container = QWidget()
        slider_layout = QHBoxLayout()
        slider_layout.addWidget(self.plot.slider)
        slider_layout.addWidget(self.plot.redshiftSpin)
        slider_container.setLayout(slider_layout)
        layout.addWidget(slider_container, row=3, col=0, colspan=2)

        self.layout = layout
        self.layout.show()

    def keyPressEvent(self, event):
        """Forward a key press to the plot widget."""
        self.plot.keyPressEvent(event)

    def mousePressEvent(self, event):
        """Forward a mouse press to the plot widget."""
        self.plot.mousePressEvent(event)

    def save_dict_todf(self):
        """Write the classification history to ``output_file`` as CSV.

        Connected to ``aboutToQuit``.  When nothing has been classified the
        file is left untouched: there is nothing to write, and an existing
        results file must not be replaced by an empty one.
        """
        history = self.plot.history
        if not history:
            print("No classifications to save.")
            return
        rows = [[objid, *record[:5]] for objid, record in history.items()]
        df_new = pd.DataFrame(rows, columns=list(self._HISTORY_COLUMNS))
        df_new.to_csv(self.output_file, index=False)
class PGSpecPlotThread(QThread):
    """Run :class:`PGSpecPlotApp` in a separate thread.

    ``spectra`` is the preferred argument; ``specfile`` is the legacy name
    and is only consulted when ``spectra`` is ``None``.  Any extra keyword
    arguments are forwarded to :class:`PGSpecPlotApp`, which is created in
    the constructor.

    :raises ValueError: if neither ``spectra`` nor ``specfile`` is given.
    """

    def __init__(self, spectra=None, SpecClass=SpecEuclid1d, specfile=None, **kwargs):
        super().__init__()
        # Backward compatibility: fall back to the legacy keyword.
        if spectra is None:
            if specfile is None:
                raise ValueError("Either 'spectra' or 'specfile' must be provided")
            spectra = specfile
        self.spectra = spectra
        self.SpecClass = SpecClass
        self.app = PGSpecPlotApp(self.spectra, self.SpecClass, **kwargs)

    def run(self):
        """Run the application's event loop and exit with its return code."""
        sys.exit(self.app.exec_())