from PySide6.QtGui import QCursor, QFont, QPixmap
from PySide6.QtCore import Qt, QThread, Signal, QTimer, QThreadPool, QRunnable
from PySide6.QtWidgets import (QApplication, QFrame, QWidget, QSlider, QHBoxLayout,
QVBoxLayout, QGridLayout, QDoubleSpinBox, QPushButton,
QLabel, QButtonGroup, QRadioButton, QGroupBox,
QFileDialog, QMessageBox, QComboBox, QProgressBar, QSpinBox,
QScrollArea, QCheckBox, QDialog, QTextEdit, QSplitter,
QLineEdit)
import sys
from ..basemodule import *
import pyqtgraph as pg
import numpy as np
from astropy.table import Table
from astropy.io import fits
from astropy.stats import sigma_clip
from astropy.coordinates import SkyCoord
from astropy import units as u
import pandas as pd
from importlib.resources import files
from pathlib import Path
import os
import requests
from urllib.parse import urlencode
import tempfile
from PIL import Image
import io
import json
import hashlib
import time
import re
import concurrent.futures
import threading
from datetime import datetime, timezone
import getpass
from specutils import Spectrum
from importlib.metadata import PackageNotFoundError, version as dist_version
from ..auxmodule.cutout_download import (
EUCLID_CUTOUT_SURVEYS,
fetch_cutout,
get_cache_filename,
is_no_data_error,
is_valid_cutout_target,
load_cutout_from_cache,
predownload_cutouts,
print_cli_progress,
save_cutout_to_cache,
)
from ..auxmodule.external_redshift import normalize_redshift_lookup_key
# Resolve the packaged template directory shipped with specbox.
data_path = Path(files("specbox").joinpath("data/templates"))
# qso1 / qso2 template sources used by TemplateManager.
fits_file = data_path.joinpath("qso1", "optical_nir_qso_template_v1.fits")
ragn_dr1_template_file = data_path.joinpath("qso1", "ragn_dr1.fits")
type2_template_file = data_path.joinpath("qso2", "type2_quasar_composite.csv")
type2_euclid_template_file = data_path.joinpath("qso2", "ragn_na.fits")
# Default Type 1 template, read at import time. Columns are renamed to the
# 'Wave'/'Flux' names used by TemplateManager.load_default_templates.
tb_temp = Table.read(str(fits_file))
tb_temp.rename_column('wavelength', 'Wave')
tb_temp.rename_column('flux', 'Flux')
# Installed distribution version of specbox; "0.0.0" when no installed
# package metadata can be found.
try:
    viewer_version = dist_version("specbox")
except PackageNotFoundError:
    viewer_version = "0.0.0"
# Rest-frame emission lines used for template annotations.
# All values are in Angstrom.
# Each entry is (display label, rest wavelength). The values look like vacuum
# wavelengths (e.g. H-alpha at 6564.61) -- NOTE(review): confirm the convention
# matches the spectra being annotated.
_TEMPLATE_EMISSION_LINES = [
    ("Ly α", 1215.67),
    ("C IV", 1549.06),
    ("C III]", 1908.73),
    ("Mg II", 2798.75),
    ("[O II]", 3728.48),
    ("Hβ", 4862.68),
    ("[O III]", 4960.30),
    ("[O III]", 5008.24),
    ("Hα", 6564.61),
    ("O I", 8448.7),
    ("[S III]", 9071.1),
    ("[S III]", 9533.2),
    ("Pa δ", 10052.1),
    ("He I", 10833.2),
    ("Pa γ", 10941.1),
    ("O I", 11290.0),
    ("Pa β", 12821.6),
]
# Four 0-255 components; presumably the RGBA colour of template overlays --
# confirm at the point of use.
_TEMPLATE_COLOR = (220, 0, 0, 230)
# Canonical class token -> label shown to the user (used by display_class_label).
_CANONICAL_CLASS_TO_DISPLAY = {
    "QSO_DEFAULT": "QSO(Default)",
    "QSO": "QSO",
    "QSO_NARROW": "QSO(Narrow)",
    "QSO_BAL": "QSO(BAL)",
    "QSO_FELOBAL": "QSO(FeLoBAL)",
    "LIKELY_Q": "LIKELY_Q",
    "GALAXY": "GALAXY",
    "STAR": "STAR",
    "UNKNOWN": "UNKNOWN",
    "BAD": "BAD",
}
# Accepted spellings -> canonical token (used by normalize_class_label).
# Covers every display label, every canonical token (mapped to itself) and the
# extra variants "QSO(narrow)" and "LIKELY". Lookups are exact / case-sensitive.
_CLASS_LABEL_ALIASES = {
    "QSO(Default)": "QSO_DEFAULT",
    "QSO_DEFAULT": "QSO_DEFAULT",
    "QSO": "QSO",
    "QSO(narrow)": "QSO_NARROW",
    "QSO(Narrow)": "QSO_NARROW",
    "QSO_NARROW": "QSO_NARROW",
    "QSO(BAL)": "QSO_BAL",
    "QSO_BAL": "QSO_BAL",
    "QSO(FeLoBAL)": "QSO_FELOBAL",
    "QSO_FELOBAL": "QSO_FELOBAL",
    "LIKELY": "LIKELY_Q",
    "LIKELY_Q": "LIKELY_Q",
    "GALAXY": "GALAXY",
    "STAR": "STAR",
    "UNKNOWN": "UNKNOWN",
    "BAD": "BAD",
}
[docs]
def normalize_class_label(value):
    """Map a class label onto its canonical token.

    ``None``, blank text and the text ``"nan"`` (any case) give ``""``.
    Known spellings are translated through ``_CLASS_LABEL_ALIASES``; any
    other label is returned stripped but otherwise unchanged.
    """
    text = "" if value is None else str(value).strip()
    if text.lower() in ("", "nan"):
        return ""
    return _CLASS_LABEL_ALIASES.get(text, text)
[docs]
def display_class_label(value):
    """Return the user-facing name for a class label.

    The label is canonicalised with ``normalize_class_label`` first; tokens
    without an entry in ``_CANONICAL_CLASS_TO_DISPLAY`` are returned as-is.
    """
    token = normalize_class_label(value)
    try:
        return _CANONICAL_CLASS_TO_DISPLAY[token]
    except KeyError:
        return token
[docs]
def default_reviewer_username():
    """Return the current login name, stripped, or ``""`` if it is unavailable."""
    try:
        login = getpass.getuser()
    except Exception:
        # getpass.getuser() raises when no login name can be determined.
        return ""
    return str(login or "").strip()
[docs]
def normalize_data_release(value, *, aimsz_review=False):
    """Clean a data-release label.

    Returns ``None`` for ``None``, blank text or ``"nan"`` (any case);
    otherwise the stripped text. In AIMS-z review mode, any label that
    mentions both "desi" and "dr1" (case-insensitive) becomes ``"DESI-DR1"``.
    """
    text = "" if value is None else str(value).strip()
    lowered = text.lower()
    if lowered in ("", "nan"):
        return None
    mentions_desi_dr1 = "desi" in lowered and "dr1" in lowered
    return "DESI-DR1" if (aimsz_review and mentions_desi_dr1) else text
[docs]
def load_template_table(template_path):
    """Read a spectral template file into ``{"wave": ndarray, "flux": ndarray}``.

    ``.csv`` files are read with astropy's ``ascii.csv`` format; anything
    else relies on ``Table.read`` format auto-detection. Columns named
    ``wavelength`` and ``flux`` (case-insensitive) are preferred; otherwise
    the first and second columns are used. Both arrays are cast to float.

    Raises:
        ValueError: if a wavelength or flux column cannot be determined.
    """
    path = Path(template_path)
    read_kwargs = {"format": "ascii.csv"} if path.suffix.lower() == ".csv" else {}
    table = Table.read(str(path), **read_kwargs)
    names = table.colnames
    by_lower_name = {str(name).lower(): name for name in names}
    wave_col = by_lower_name.get("wavelength", names[0] if len(names) > 0 else None)
    flux_col = by_lower_name.get("flux", names[1] if len(names) > 1 else None)
    if wave_col is None or flux_col is None:
        raise ValueError(f"Template has unexpected columns: {path}")
    wave = np.asarray(table[wave_col], dtype=float)
    flux = np.asarray(table[flux_col], dtype=float)
    return {"wave": wave, "flux": flux}
def _is_dataframe_backed_spectrum_path(path):
    """Return whether *path* is a dataframe-backed spectrum source.

    Module-level alias for ``SpecEuclid1d._is_dataframe_backed_path``.
    Callers here then read such files with ``file_format="parquet"``.
    """
    checker = SpecEuclid1d._is_dataframe_backed_path
    return checker(path)
def _choose_dual_pair_key_column(rgs_df, bgs_df):
    """Pick the column used to pair RGS and BGS rows.

    ``extname`` wins over ``objid``; ``None`` means neither column is
    present in both frames.
    """
    shared = (
        name
        for name in ("extname", "objid")
        if name in rgs_df.columns and name in bgs_df.columns
    )
    return next(shared, None)
def _dual_pair_key(value):
    """Normalise one pairing-key cell to a stripped string, or ``None`` if missing.

    Missing means ``None`` or a scalar NA marker recognised by ``pd.isna``
    (NaN, ``pd.NA``, ``NaT``), or one of the placeholder strings ``""``,
    ``"nan"``, ``"None"``, ``"<NA>"`` and ``"NaT"`` that such values become
    once a column has been stringified.
    """
    if value is None:
        return None
    try:
        if pd.isna(value):
            return None
    except (TypeError, ValueError):
        # Non-scalar cell: pd.isna returned an array; judge it by its text.
        pass
    key = str(value).strip()
    return None if key in ("", "nan", "None", "<NA>", "NaT") else key


def _ordered_shared_dual_pair_keys(rgs_file, bgs_file):
    """Return the pairing keys shared by an RGS and a BGS dataframe file.

    Returns ``(keys, key_column)``. *keys* follows the RGS row order, with
    missing keys (see ``_dual_pair_key``) and duplicates removed and only
    keys that also occur in the BGS file kept. *key_column* is the column
    chosen by ``_choose_dual_pair_key_column``. ``([], None)`` is returned
    when either path is not dataframe-backed or no shared key column exists.
    """
    if not (_is_dataframe_backed_spectrum_path(rgs_file) and _is_dataframe_backed_spectrum_path(bgs_file)):
        return [], None
    rgs_df = SpecPandasRow._read_dataframe_file(rgs_file, file_format="parquet")
    bgs_df = SpecPandasRow._read_dataframe_file(bgs_file, file_format="parquet")
    key_column = _choose_dual_pair_key_column(rgs_df, bgs_df)
    if key_column is None:
        return [], None
    bgs_keys = {key for key in map(_dual_pair_key, bgs_df[key_column]) if key is not None}
    # dict.fromkeys de-duplicates while keeping first-seen (RGS) order.
    ordered = list(dict.fromkeys(
        key
        for key in map(_dual_pair_key, rgs_df[key_column])
        if key is not None and key in bgs_keys
    ))
    return ordered, key_column
[docs]
class AIMSZReviewPanel(QWidget):
    """Read-only review context plus reviewer-editable fields for AIMS-z.

    The panel contains:

    - a grid of context values read from the current spectrum, falling back
      to the history record;
    - two QA checkboxes whose state is encoded with ``ImageCutoutWidget``'s
      QA bits;
    - a notes editor, hidden by default;
    - a reviewer name field and the last review timestamp.

    Signals:
        qa_flag_changed(int): combined QA flag after a checkbox toggle.
        notes_changed(str): full notes text after an edit.
        reviewer_changed(str): reviewer name after an edit.

    No signal is emitted while :meth:`set_review_context` repopulates the
    widgets.
    """

    qa_flag_changed = Signal(int)
    notes_changed = Signal(str)
    reviewer_changed = Signal(str)

    def __init__(self, default_reviewer=""):
        super().__init__()
        # True while widgets are filled programmatically; suppresses the
        # *_changed signals so only user edits are reported.
        self._updating = False
        # Spec attribute / record key -> QLabel showing its value.
        self._context_labels = {}
        self._default_reviewer = str(default_reviewer or "")
        self.setFixedWidth(320)
        self._setup_ui()

    def _setup_ui(self):
        """Build the context, QA and editor group boxes and connect their signals."""
        layout = QVBoxLayout()

        context_group = QGroupBox("AIMS-z Review Context")
        context_layout = QGridLayout()
        # (spec attribute / record key, label text shown in the panel)
        context_fields = [
            ("object_id", "object_id"),
            ("targetid", "targetid"),
            ("review_priority_tier", "tier"),
            ("review_score", "score"),
            ("review_rank_within_tier", "rank"),
            ("review_slice_label", "slice"),
            ("z_ref", "z_ref"),
            ("z_ml_expect", "z_ml"),
            ("z_pcf_best", "z_pcf"),
            ("pcf_template_best", "pcf_template"),
            ("pcf_score_best", "pcf_score"),
        ]
        for row_idx, (field, label) in enumerate(context_fields):
            context_layout.addWidget(QLabel(f"{label}:"), row_idx, 0)
            value_label = QLabel("-")
            value_label.setWordWrap(True)
            value_label.setTextInteractionFlags(Qt.TextSelectableByMouse | Qt.TextSelectableByKeyboard)
            context_layout.addWidget(value_label, row_idx, 1)
            self._context_labels[field] = value_label
        context_group.setLayout(context_layout)
        layout.addWidget(context_group)

        qa_group = QGroupBox("Review QA")
        qa_layout = QVBoxLayout()
        self.qa_contamination_cb = QCheckBox("Contamination from nearby\nsource(s)")
        self.qa_unusable_cb = QCheckBox("Unusable spectrum due to\ndominating artifacts")
        qa_checkbox_style = "QCheckBox::indicator { width: 24px; height: 24px; }"
        self.qa_contamination_cb.setStyleSheet(qa_checkbox_style)
        self.qa_unusable_cb.setStyleSheet(qa_checkbox_style)
        qa_layout.addWidget(self.qa_contamination_cb)
        qa_layout.addWidget(self.qa_unusable_cb)
        qa_group.setLayout(qa_layout)
        layout.addWidget(qa_group)

        editor_group = QGroupBox("Reviewer Edits")
        editor_layout = QGridLayout()
        self.notes_label = QLabel("Notes:")
        editor_layout.addWidget(self.notes_label, 0, 0)
        self.notes_edit = QTextEdit()
        self.notes_edit.setMinimumHeight(90)
        # The text inputs only take keyboard focus when clicked.
        self.notes_edit.setFocusPolicy(Qt.ClickFocus)
        editor_layout.addWidget(self.notes_edit, 0, 1)
        self.reviewer_label = QLabel("Reviewer:")
        editor_layout.addWidget(self.reviewer_label, 1, 0)
        self.reviewer_edit = QLineEdit()
        self.reviewer_edit.setFocusPolicy(Qt.ClickFocus)
        editor_layout.addWidget(self.reviewer_edit, 1, 1)
        self.reviewed_at_title = QLabel("Reviewed at:")
        editor_layout.addWidget(self.reviewed_at_title, 2, 0)
        self.reviewed_at_label = QLabel("-")
        self.reviewed_at_label.setTextInteractionFlags(Qt.TextSelectableByMouse | Qt.TextSelectableByKeyboard)
        editor_layout.addWidget(self.reviewed_at_label, 2, 1)
        editor_group.setLayout(editor_layout)
        layout.addWidget(editor_group)

        layout.addStretch()
        self.setLayout(layout)
        self.qa_contamination_cb.stateChanged.connect(self._emit_qa_flag)
        self.qa_unusable_cb.stateChanged.connect(self._emit_qa_flag)
        self.notes_edit.textChanged.connect(self._emit_notes)
        self.reviewer_edit.textChanged.connect(self._emit_reviewer)
        self.hide_notes_editor()

    def hide_notes_editor(self):
        """Hide the notes label and editor (the panel starts with them hidden)."""
        self.notes_label.setVisible(False)
        self.notes_edit.setVisible(False)

    def set_default_reviewer(self, reviewer):
        """Set the reviewer name shown when a record carries none."""
        self._default_reviewer = str(reviewer or "")

    @staticmethod
    def _qa_flag_from_bits(contamination, unusable):
        """Combine the two QA booleans into ``ImageCutoutWidget``'s bit flag."""
        flag = 0
        if contamination:
            flag |= ImageCutoutWidget.QA_CONTAMINATION_BIT
        if unusable:
            flag |= ImageCutoutWidget.QA_UNUSABLE_BIT
        return int(flag)

    @staticmethod
    def _is_missing(value):
        """Return True for ``None``, blank strings and NA scalars (NaN, ``pd.NA``)."""
        if value is None:
            return True
        if isinstance(value, str):
            return value.strip() == ""
        try:
            return bool(pd.isna(value))
        except (TypeError, ValueError):
            # Non-scalars (e.g. arrays) cannot be judged; treat them as present.
            return False

    @classmethod
    def _coerce_qa_flag(cls, value):
        """Convert a stored ``qa_flag`` to int. Missing or unparseable values give 0."""
        if cls._is_missing(value):
            return 0
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            try:
                # Accept float-formatted values such as "2.0".
                return int(float(value))
            except (TypeError, ValueError, OverflowError):
                return 0

    @classmethod
    def _record_text(cls, record, key):
        """Return ``record[key]`` as text, or ``""`` when absent, falsy or NA-like."""
        if not isinstance(record, dict):
            return ""
        value = record.get(key)
        return "" if cls._is_missing(value) else str(value or "")

    def _emit_qa_flag(self):
        if self._updating:
            return
        self.qa_flag_changed.emit(
            self._qa_flag_from_bits(
                self.qa_contamination_cb.isChecked(),
                self.qa_unusable_cb.isChecked(),
            )
        )

    def _emit_notes(self):
        if self._updating:
            return
        self.notes_changed.emit(self.notes_edit.toPlainText())

    def _emit_reviewer(self, text):
        if self._updating:
            return
        self.reviewer_changed.emit(text)

    def set_review_context(self, spec, record=None):
        """Populate the panel from *spec* and its history *record*.

        Context labels show ``getattr(spec, field)``. When that is ``None``
        and *record* is a dict, ``record[field]`` is used instead. Missing
        values are shown as ``-``.

        The QA checkboxes, notes, reviewer and timestamp come from *record*
        only; a non-dict record resets them and uses the default reviewer.
        NA-like record values (NaN, ``pd.NA``) count as missing. Such values
        may appear, e.g., in a history table read with pandas. They no longer
        crash ``int()`` or show up as the text ``"nan"``.

        No ``*_changed`` signals are emitted during the update.
        """
        has_record = isinstance(record, dict)
        self._updating = True
        try:
            for field, value_label in self._context_labels.items():
                value = getattr(spec, field, None)
                if value is None and has_record:
                    value = record.get(field)
                value_label.setText("-" if self._is_missing(value) else str(value))
            qa_flag = self._coerce_qa_flag(record.get("qa_flag", 0)) if has_record else 0
            self.qa_contamination_cb.setChecked((qa_flag & ImageCutoutWidget.QA_CONTAMINATION_BIT) != 0)
            self.qa_unusable_cb.setChecked((qa_flag & ImageCutoutWidget.QA_UNUSABLE_BIT) != 0)
            self.notes_edit.setPlainText(self._record_text(record, "notes"))
            self.reviewer_edit.setText(self._record_text(record, "reviewer") or self._default_reviewer)
            reviewed_at = self._record_text(record, "reviewed_at")
            self.reviewed_at_label.setText(reviewed_at if reviewed_at else "-")
        finally:
            self._updating = False
[docs]
def get_qa_flag(self):
    """Return the QA bit flag encoded by the two QA checkboxes.

    Bits come from the owning class's ``QA_CONTAMINATION_BIT`` and
    ``QA_UNUSABLE_BIT``. NOTE(review): that class header is outside this
    chunk; confirm it defines both constants.
    """
    contamination = self.QA_CONTAMINATION_BIT if self.qa_contamination_cb.isChecked() else 0
    unusable = self.QA_UNUSABLE_BIT if self.qa_unusable_cb.isChecked() else 0
    return int(contamination | unusable)
[docs]
def set_qa_flag(self, qa_flag):
    """Set both QA checkboxes from a combined bit flag without emitting signals.

    *qa_flag* may be an int, a float or a numeric string (``"2"`` or
    ``"2.0"``). ``None``, NaN and other unparseable values clear both boxes.

    Each checkbox's previous signal-blocking state is restored afterwards,
    even if ``setChecked`` raises. Uses the owning class's
    ``QA_CONTAMINATION_BIT`` / ``QA_UNUSABLE_BIT`` (class header not visible
    here; confirm).
    """
    try:
        flag = int(qa_flag)
    except (TypeError, ValueError, OverflowError):
        try:
            # Accept float-formatted values such as "2.0" from text tables.
            flag = int(float(qa_flag))
        except (TypeError, ValueError, OverflowError):
            flag = 0
    boxes = (
        (self.qa_contamination_cb, self.QA_CONTAMINATION_BIT),
        (self.qa_unusable_cb, self.QA_UNUSABLE_BIT),
    )
    # blockSignals() returns the previous state. Restoring it means a caller
    # that had already blocked signals is not silently unblocked.
    previous = [box.blockSignals(True) for box, _ in boxes]
    try:
        for box, bit in boxes:
            box.setChecked((flag & bit) != 0)
    finally:
        for (box, _), was_blocked in zip(boxes, previous):
            box.blockSignals(was_blocked)
# Note: load_local_cutouts() was removed together with the "Load Local" button.
[docs]
def load_online_cutouts(self, ra, dec, objid=None):
    """Load cutouts from HiPS2FITS service for given coordinates.

    Does nothing unless the auto-fetch checkbox is ticked. Invalid
    coordinates (``None`` or NaN) are reported and skipped.

    For each survey in ``EUCLID_CUTOUT_SURVEYS``:

    - the cached cutout for *objid* is shown when one exists;
    - otherwise the cutout is downloaded with ``fetch_cutout``
      (150x150 px, size in arcsec from the size combo box);
    - downloads are cached when *objid* is truthy.

    Per-band errors become status messages and do not stop the remaining
    bands. Bands are processed sequentially in the calling thread. A call
    made while another fetch is in progress returns immediately.

    The in-progress flag and progress bar are always reset, including when
    the size combo holds a non-numeric value.
    """
    if not self.auto_fetch_cb.isChecked():
        return
    if ra is None or dec is None or np.isnan(ra) or np.isnan(dec):
        self.coord_label.setText("RA, DEC = -, -")
        self.add_status_message("No valid coordinates available")
        return
    # Update coordinate display
    self.coord_label.setText(f"RA, DEC = {ra:.6f}, {dec:.6f}")
    # Prevent multiple simultaneous fetches
    if self._fetch_in_progress:
        return
    self._fetch_in_progress = True
    self.clear_images()
    self.progress.setVisible(True)
    self.progress.setRange(0, 0)  # Indeterminate progress
    try:
        # Everything after the flag is set stays inside try/finally. A bad
        # size value used to raise before the try, leaving _fetch_in_progress
        # stuck at True and silently disabling every later fetch.
        size_arcsec = float(self.size_combo.currentText())
        # Fetch bands sequentially to avoid threading issues
        for survey, band_name in list(EUCLID_CUTOUT_SURVEYS):
            try:
                cached_data = None
                if objid:
                    cached_data = self.load_cutout_from_cache(objid, survey)
                if cached_data is not None:
                    self.add_status_message(f"Loading {band_name} from cache...")
                    self.add_image_from_rgb_array(cached_data, band_name)
                    continue
                self.add_status_message(f"Fetching {band_name}...")
                result = fetch_cutout(
                    ra=ra,
                    dec=dec,
                    survey_name=survey,
                    size_arcsec=size_arcsec,
                    width=150,
                    height=150,
                )
                if result is None:
                    self.add_status_message(f"No data returned for {band_name}")
                    continue
                if objid:
                    self.save_cutout_to_cache(objid, survey, result)
                # result is an RGB numpy array
                self.add_image_from_rgb_array(result, band_name)
            except Exception as e:
                self.add_status_message(f"Error fetching {band_name}: {str(e)}")
    except Exception as e:
        self.add_status_message(f"General error in image fetching: {str(e)}")
    finally:
        self.progress.setVisible(False)
        self._fetch_in_progress = False
[docs]
def add_image_from_pil(self, pil_image, label=""):
    """Add PIL image directly to display (for JPEG format).

    The image is encoded to an in-memory PNG, decoded into a ``QPixmap``
    and passed to ``add_image_widget``. Failures are reported through
    ``add_status_message`` instead of being raised.
    """
    try:
        with io.BytesIO() as png_stream:
            pil_image.save(png_stream, format='PNG')
            png_bytes = png_stream.getvalue()
            pixmap = QPixmap()
            pixmap.loadFromData(png_bytes)
        self.add_image_widget(pixmap, label)
    except Exception as exc:
        self.add_status_message(f"Error processing image {label}: {str(exc)}")
[docs]
def add_image_from_array(self, data, label=""):
    """Add image from numpy array to display.

    NaN/inf are replaced with ``np.nan_to_num``. All-zero input is reported
    as empty and skipped; non-2-D input is reported as an error.

    Pixels are arcsinh-stretched, with a x500 boost when *label* contains
    "VIS" or "$I_". They are then mapped linearly to 0-255 between the 1st
    and 99th percentiles of the positive stretched pixels, or of all pixels
    when none are positive. A flat stretch gives a black image.

    Any failure is reported via ``add_status_message``.
    """
    try:
        data_scaled = np.nan_to_num(data)
        if np.all(data_scaled == 0):
            self.add_status_message(f"Empty image data for {label}")
            return
        if data_scaled.ndim != 2:
            raise ValueError(f"expected a 2-D array, got shape {data_scaled.shape}")
        # Use arcsinh scaling like in notebook for VIS band
        if "VIS" in label or "$I_" in label:
            data_scaled = np.arcsinh(data_scaled * 500)
        else:
            data_scaled = np.arcsinh(data_scaled)
        positive = data_scaled[data_scaled > 0]
        # np.percentile cannot give meaningful limits for an empty selection
        # (it errors or yields NaN depending on the NumPy version). Images
        # without positive pixels therefore use all pixels instead.
        reference = positive if positive.size else data_scaled
        vmin, vmax = np.percentile(reference, [1, 99])
        if vmax > vmin:
            data_scaled = np.clip((data_scaled - vmin) / (vmax - vmin) * 255, 0, 255)
        else:
            data_scaled = np.zeros_like(data_scaled)
        # A 2-D uint8 array maps to Pillow's 8-bit greyscale ("L") mode.
        image = Image.fromarray(data_scaled.astype(np.uint8))
        with io.BytesIO() as buffer:
            image.save(buffer, format='PNG')
            pixmap = QPixmap()
            pixmap.loadFromData(buffer.getvalue())
        self.add_image_widget(pixmap, label)
    except Exception as e:
        self.add_status_message(f"Error processing image {label}: {str(e)}")
[docs]
def add_image_from_rgb_array(self, rgb_data, label=""):
    """Add RGB image from numpy array (for JPEG format).

    *rgb_data* must have shape ``(height, width, 3)``; uint8 is expected.
    Other shapes are reported via ``add_status_message`` and skipped, as
    are any conversion errors.
    """
    try:
        shape = rgb_data.shape
        if not (len(shape) == 3 and shape[2] == 3):
            self.add_status_message(f"Expected RGB array for {label}, got shape {rgb_data.shape}")
            return
        rgb_image = Image.fromarray(rgb_data, mode='RGB')
        with io.BytesIO() as png_stream:
            rgb_image.save(png_stream, format='PNG')
            pixmap = QPixmap()
            pixmap.loadFromData(png_stream.getvalue())
        self.add_image_widget(pixmap, label)
    except Exception as exc:
        self.add_status_message(f"Error processing RGB image {label}: {str(exc)}")
[docs]
def create_composite_image(self, ra, dec, size_arcsec):
    """Create RGB composite from multiple bands.

    Placeholder: composites are not implemented yet, so this only posts a
    status message. *ra*, *dec* and *size_arcsec* are currently unused.
    """
    notice = "Composite image creation not yet implemented"
    try:
        self.add_status_message(notice)
    except Exception as exc:
        self.add_status_message(f"Error creating composite: {str(exc)}")
[docs]
def add_image_from_fits(self, fits_path, label=""):
    """Add FITS image to display.

    The first image HDU (primary, image or compressed image) that carries
    data is used, checked in file order. Files whose primary HDU is empty
    therefore still display their first image extension.

    NaN/inf are replaced with ``np.nan_to_num``. Pixels are mapped linearly
    to 0-255 between the 1st and 99th percentiles. A constant image is shown
    black instead of dividing by zero.

    Files without image data, non-2-D data and read errors are reported
    via ``add_status_message``.
    """
    try:
        image_hdu_types = (fits.PrimaryHDU, fits.ImageHDU, fits.CompImageHDU)
        with fits.open(fits_path) as hdul:
            data = next(
                (hdu.data for hdu in hdul
                 if isinstance(hdu, image_hdu_types) and hdu.data is not None),
                None,
            )
            if data is None:
                self.add_status_message(f"No image data in FITS {fits_path}")
                return
            data_scaled = np.nan_to_num(data)
            if data_scaled.ndim != 2:
                raise ValueError(f"expected a 2-D image, got shape {data_scaled.shape}")
            vmin, vmax = np.percentile(data_scaled, [1, 99])
            if vmax > vmin:
                data_scaled = np.clip((data_scaled - vmin) / (vmax - vmin) * 255, 0, 255)
            else:
                # Constant image: avoid 0/0 (NaN) and show it black.
                data_scaled = np.zeros_like(data_scaled)
            # A 2-D uint8 array maps to Pillow's 8-bit greyscale ("L") mode.
            image = Image.fromarray(data_scaled.astype(np.uint8))
            with io.BytesIO() as buffer:
                image.save(buffer, format='PNG')
                pixmap = QPixmap()
                pixmap.loadFromData(buffer.getvalue())
            self.add_image_widget(pixmap, label)
    except Exception as e:
        self.add_status_message(f"Error loading FITS {fits_path}: {str(e)}")
[docs]
def add_image_from_file(self, file_path):
    """Add image from file (FITS or regular image formats).

    Names ending in ``.fits``, ``.fit`` or ``.fts`` (any case, optionally
    followed by ``.gz``) go to ``add_image_from_fits`` with the name minus
    that extension as the label. Everything else is loaded with
    ``QPixmap`` and labelled with the file stem.

    *file_path* may be a ``str`` or any ``os.PathLike``.
    """
    path = Path(file_path)
    name = path.name
    lowered = name.lower()
    for suffix in (".fits.gz", ".fit.gz", ".fts.gz", ".fits", ".fit", ".fts"):
        if lowered.endswith(suffix):
            self.add_image_from_fits(file_path, name[: -len(suffix)])
            return
    # Regular image file; QPixmap expects a string path.
    pixmap = QPixmap(os.fspath(file_path))
    if not pixmap.isNull():
        self.add_image_widget(pixmap, path.stem)
    else:
        self.add_status_message(f"Could not load image: {file_path}")
[docs]
def add_status_message(self, message):
    """Add status message to display.

    The text is appended to ``images_layout`` as a word-wrapped, green,
    italic QLabel. After 3 seconds ``remove_widget`` is called on it, unless
    the weakly referenced label has disappeared by then.
    """
    import weakref

    status = QLabel(message)
    status.setWordWrap(True)
    status.setStyleSheet("color: green; font-style: italic;")
    self.images_layout.addWidget(status)
    status_ref = weakref.ref(status)

    def _expire():
        target = status_ref()
        if target:
            self.remove_widget(target)

    QTimer.singleShot(3000, _expire)
[docs]
def clear_images(self):
    """Clear all displayed images.

    Every item is removed from ``images_layout``; widgets it held are
    scheduled for deletion with ``deleteLater``.
    """
    layout = self.images_layout
    while layout.count():
        item = layout.takeAt(0)
        widget = item.widget()
        if widget:
            widget.deleteLater()
[docs]
def get_cache_filename(self, objid, survey_name):
    """Get cache filename for an object and survey.

    Delegates to the module-level ``get_cache_filename`` imported from
    ``cutout_download``, passing this widget's ``buffer_dir``.
    """
    cache_root = self.buffer_dir
    return get_cache_filename(cache_root, objid, survey_name)
[docs]
def save_cutout_to_cache(self, objid, survey_name, image_data):
    """Save cutout image to cache.

    Does nothing when ``buffer_dir`` is unset. Write failures are printed
    together with the target cache path rather than raised.
    """
    cache_root = self.buffer_dir
    if cache_root:
        try:
            save_cutout_to_cache(cache_root, objid, survey_name, image_data)
        except Exception as exc:
            target = self.get_cache_filename(objid, survey_name)
            print(f"Error saving cache file {target}: {exc}")
[docs]
def load_cutout_from_cache(self, objid, survey_name):
    """Load cutout from cache if exists.

    Returns the result of the module-level ``load_cutout_from_cache`` for
    this widget's ``buffer_dir``. Returns ``None`` when ``buffer_dir`` is
    unset or reading fails; failures are printed with the cache path.
    """
    cache_root = self.buffer_dir
    if not cache_root:
        return None
    try:
        cached = load_cutout_from_cache(cache_root, objid, survey_name)
    except Exception as exc:
        target = self.get_cache_filename(objid, survey_name)
        print(f"Error loading cache file {target}: {exc}")
        return None
    return cached
[docs]
def prefetch_cutouts_background(self, spectrum_list, current_index, num_prefetch=None, size_arcsec=10):
    """Prefetch cutouts for upcoming spectra in the background.

    A daemon thread walks the entries after *current_index*: all remaining
    ones when *num_prefetch* is ``None``, otherwise at most *num_prefetch*.

    For each entry that exposes ``ra``, ``dec`` and ``objid`` and passes
    ``is_valid_cutout_target``, every survey in ``EUCLID_CUTOUT_SURVEYS``
    without a cached cutout is downloaded with ``fetch_cutout`` (150x150 px)
    and written to the cache. Download attempts are spaced by 0.1 s. Errors
    are printed, except no-data errors, which are silent.

    Args:
        spectrum_list: sequence of spectrum objects.
        current_index: index of the spectrum currently shown.
        num_prefetch: maximum number of following entries to scan, or
            ``None`` for all.
        size_arcsec: cutout size passed to ``fetch_cutout``. Defaults to
            10, the previously hard-coded value.

    Returns:
        The started ``threading.Thread`` so callers may join it, or ``None``
        when there is no cache directory or nothing after *current_index*.
    """
    if not self.buffer_dir or current_index >= len(spectrum_list) - 1:
        return None

    def prefetch_worker():
        # Download all remaining cutouts if num_prefetch is None
        if num_prefetch is None:
            end_index = len(spectrum_list)
        else:
            end_index = min(current_index + num_prefetch + 1, len(spectrum_list))
        surveys = list(EUCLID_CUTOUT_SURVEYS)
        for i in range(current_index + 1, end_index):
            try:
                spec_data = spectrum_list[i]
                if not all(hasattr(spec_data, attr) for attr in ("ra", "dec", "objid")):
                    continue
                ra, dec, objid = spec_data.ra, spec_data.dec, spec_data.objid
                # Target validity does not depend on the survey: check it once.
                if not is_valid_cutout_target(objid, ra, dec):
                    continue
                for survey, _ in surveys:
                    if self.load_cutout_from_cache(objid, survey) is not None:
                        continue  # Already cached
                    try:
                        result = fetch_cutout(
                            ra=ra,
                            dec=dec,
                            survey_name=survey,
                            size_arcsec=size_arcsec,
                            width=150,
                            height=150,
                        )
                        if result is not None:
                            self.save_cutout_to_cache(objid, survey, result)
                        time.sleep(0.1)
                    except Exception as e:
                        if not is_no_data_error(e):
                            print(f"Error prefetching cutout for object {objid}: {e}")
                        time.sleep(0.1)
            except Exception as e:
                print(f"Error processing spectrum {i} for prefetch: {e}")

    # Run prefetch in background thread
    worker = threading.Thread(target=prefetch_worker, daemon=True)
    worker.start()
    return worker
[docs]
class TemplateManager:
    """Manages spectrum templates.

    Templates are stored by name as dicts with ``wave``, ``flux`` and
    ``description`` keys. Optional packaged templates that could not be
    loaded stay in ``templates`` with a ``None`` value and are excluded
    from :meth:`get_available_templates`.
    """

    def __init__(self):
        # name -> template dict, or None for an optional template that failed to load
        self.templates = {}
        self.current_template = "Type 1"
        self.load_default_templates()

    def load_default_templates(self):
        """Load default templates.

        "Type 1" comes from the module-level table ``tb_temp``. "ragn_dr1",
        "Type 2" and "type2_euclid" are read from packaged files when they
        exist. Missing files and read errors are printed and leave the entry
        as ``None``.
        """
        self.templates["Type 1"] = {
            'wave': tb_temp['Wave'].data,
            'flux': tb_temp['Flux'].data,
            'description': 'Type 1 AGN/QSO Template'
        }
        self._load_optional_template(
            "ragn_dr1", ragn_dr1_template_file, 'ragn_dr1 AGN/QSO Template'
        )
        self._load_optional_template(
            "Type 2", type2_template_file, 'Type 2 AGN/QSO Template'
        )
        self._load_optional_template(
            "type2_euclid", type2_euclid_template_file, 'Type 2 Euclid AGN/QSO Template (ragn_na)'
        )

    def _load_optional_template(self, name, template_file, description):
        """Best-effort load of one packaged template file into ``self.templates[name]``."""
        self.templates[name] = None
        try:
            if template_file.exists():
                template = load_template_table(template_file)
                self.templates[name] = {
                    'wave': template['wave'],
                    'flux': template['flux'],
                    'description': description,
                }
            else:
                print(f"{name} template file not found: {template_file}")
        except Exception as e:
            print(f"Failed to load {name} template from {template_file}: {e}")

    def get_template(self, template_name):
        """Get template by name (``None`` if unknown or failed to load)."""
        return self.templates.get(template_name)

    def get_available_templates(self):
        """Get list of available template names, in insertion order."""
        return [name for name, template in self.templates.items() if template is not None]

    def add_template(self, name, wave, flux, description=""):
        """Add new template, replacing any existing one with the same name."""
        self.templates[name] = {
            'wave': wave,
            'flux': flux,
            'description': description
        }
[docs]
class PGSpecPlotEnhanced(pg.PlotWidget):
"""Enhanced version of PGSpecPlot with image cutouts and template switching."""
coordinate_changed = Signal(float, float) # Signal for coordinate updates
current_spec_changed = Signal()
record_committed = Signal(object)
def __init__(self, spectra, SpecClass=SpecEuclid1d, initial_counter=0,
             z_max=5.0, history_dict=None, euclid_fits=None,
             rgs_file=None, bgs_file=None, ext=None, extname=None,
             dual_good_pixels_only=False, external_redshift_lookup=None,
             external_redshift_key="object_id"):
    """Create the plot, its redshift controls and info label, then show the first spectrum.

    Args:
        spectra: a list/tuple of spectrum sources, or a single file. For a
            single file the entries are counted with
            ``SpecClass.count_in_file`` when available, otherwise as the
            number of FITS HDUs minus the primary. Not used in dual mode.
        SpecClass: spectrum class. AIMS-z review mode is enabled when it is a
            subclass of ``SpecAIMSZReview``.
        initial_counter: index of the first spectrum to show. It is reset to
            0 when it is not below the number of spectra.
        z_max: upper limit of the redshift slider and spin box (lower limit 0).
        history_dict: existing classification history. A new dict is used
            when ``None``.
        euclid_fits: stored as ``self.euclid_fits`` (its use is not visible here).
        rgs_file, bgs_file: when both are truthy, dual mode is enabled.
        ext, extname: in dual mode, if either is set only one entry is
            browsable (``len_list = 1``).
        dual_good_pixels_only: stored as a bool for dual mode.
        external_redshift_lookup: optional mapping of normalised key -> redshift.
        external_redshift_key: spectrum attribute used as the lookup key.
    """
    super().__init__()
    self.SpecClass = SpecClass
    self._is_aimsz_review = issubclass(SpecClass, SpecAIMSZReview)
    self.template_manager = TemplateManager()
    self.euclid_fits = euclid_fits
    # Copied, so later changes to the caller's mapping have no effect here.
    self.external_redshift_lookup = dict(external_redshift_lookup or {})
    self.external_redshift_key = str(external_redshift_key or "object_id")
    self._euclid_overlay_cache = {}
    # Dual mode: the RGS/BGS pair replaces ``spectra`` as the data source.
    self.dual_rgs_file = rgs_file
    self.dual_bgs_file = bgs_file
    self.dual_ext = ext
    self.dual_extname = extname
    self.dual_good_pixels_only = bool(dual_good_pixels_only)
    self._dual_mode = bool(self.dual_rgs_file and self.dual_bgs_file)
    self._dual_parquet_mode = self._dual_mode and _is_dataframe_backed_spectrum_path(self.dual_rgs_file) and _is_dataframe_backed_spectrum_path(self.dual_bgs_file)
    # Shared pairing keys (parquet dual mode) and the column they were read from.
    self._dual_pair_keys = []
    self._dual_pair_key_column = None
    self.spec_dual = None
    self._legend = None
    self._observed_wmin = None
    self._observed_wmax = None
    self._annotation_wave = None
    self._annotation_flux = None
    # View-lock state; manual range changes are routed to _on_manual_range_changed below.
    self._view_lock_active = False
    self._locked_view_range = None
    self._suspend_view_lock_updates = False
    self.downsampling_enabled = False
    self._downsampling_auto = False
    self._downsampling_method = "mean"
    self._downsampling_clip_to_view = True
    # ``spectra`` can either be a FITS file containing multiple extensions
    # or a list of individual spectrum files.
    if self._dual_mode:
        self.speclist = None
        self.specfile = self.dual_rgs_file
        if self._dual_parquet_mode:
            self._dual_pair_keys, self._dual_pair_key_column = _ordered_shared_dual_pair_keys(
                self.dual_rgs_file,
                self.dual_bgs_file,
            )
            # An explicit ext/extname restricts the session to a single entry.
            if self.dual_ext is not None or self.dual_extname not in (None, ""):
                self.len_list = 1
            else:
                self.len_list = len(self._dual_pair_keys)
        elif self.dual_ext is not None or self.dual_extname not in (None, ""):
            self.len_list = 1
        else:
            # Non-parquet dual mode: count HDUs in both files and only allow
            # browsing up to the smaller count.
            rgs_count = self._count_hdus_in_file(self.dual_rgs_file)
            bgs_count = self._count_hdus_in_file(self.dual_bgs_file)
            if rgs_count != bgs_count:
                print(
                    f"Warning: RGS/BGS extension count mismatch ({rgs_count} vs {bgs_count}); "
                    f"using min={min(rgs_count, bgs_count)}."
                )
            self.len_list = min(rgs_count, bgs_count)
    elif isinstance(spectra, (list, tuple)):
        self.speclist = list(spectra)
        self.specfile = None
        self.len_list = len(self.speclist)
    else:
        self.specfile = spectra
        if hasattr(SpecClass, "count_in_file"):
            self.len_list = int(SpecClass.count_in_file(spectra))
        else:
            # One entry per HDU after the primary.
            with fits.open(spectra) as hdul:
                self.len_list = len(hdul) - 1
        self.speclist = None
    if initial_counter >= self.len_list:
        print("No more spectra to plot.\n\t Plotting the first spectrum.")
        initial_counter = 0
    self.history = history_dict if history_dict is not None else {}
    self.setBackground('w')
    self.showGrid(x=True, y=True)
    self.setMouseEnabled(x=True, y=True)
    self.setLogMode(x=False, y=False)
    self.setAspectLocked(False)
    left_axis = self.getAxis('left')
    left_axis.enableAutoSIPrefix(False)
    self.enableAutoRange()
    self.vb = self.getViewBox()
    # Left-drag draws a zoom rectangle instead of panning.
    self.vb.setMouseMode(self.vb.RectMode)
    self.vb.sigRangeChangedManually.connect(self._on_manual_range_changed)
    self.z_min = 0.0
    self.z_max = z_max
    self.base_z_step = 0.001
    self.slider_min = 0
    # Slider positions count steps of base_z_step in ln(1 + z):
    # slider_max = ln((1 + z_max) / (1 + z_min)) / base_z_step.
    self.slider_max = int((1/self.base_z_step) * np.log((1+self.z_max)/(1+self.z_min)))
    self.message = ''
    self.counter = initial_counter
    self._last_committed_objid = None
    # Create slider and spin box for redshift control
    self.slider = QSlider(Qt.Horizontal)
    self.slider.setMinimum(self.slider_min)
    self.slider.setMaximum(self.slider_max)
    self.slider.setTickPosition(QSlider.TicksBelow)
    self.slider.setTickInterval(max(1, int((self.slider_max - self.slider_min) / 10)))
    self.slider.valueChanged.connect(self.slider_changed)
    # Create spectrum info title box
    self.spectrum_info_label = QLabel()
    self.spectrum_info_label.setFont(QFont("Arial", 14))
    self.spectrum_info_label.setStyleSheet("""
QLabel {
background-color: #f0f0f0;
border: 2px solid #666666;
padding: 8px;
color: black;
border-radius: 5px;
}
""")
    self.spectrum_info_label.setAlignment(Qt.AlignCenter)
    self.spectrum_info_label.setText("Loading spectrum info...")
    self.redshiftSpin = QDoubleSpinBox()
    self.redshiftSpin.setMinimum(self.z_min)
    self.redshiftSpin.setMaximum(self.z_max)
    self.redshiftSpin.setDecimals(4)
    self.redshiftSpin.setSingleStep(0.001)
    self.redshiftSpin.valueChanged.connect(self.spin_changed)
    self.slider.setStyleSheet("""
QSlider {
background-color: white;
}
QSlider::groove:horizontal {
border: 1px solid #999999;
height: 8px;
background: #b0c4de;
margin: 2px 0;
}
QSlider::handle:horizontal {
background: #6495ed;
border: 1px solid #5c5c5c;
width: 18px;
margin: -2px 0;
border-radius: 3px;
}
""")
    # Draw the first spectrum (plot_next is defined outside this chunk).
    self.plot_next()
@staticmethod
def _is_sparcl_like_spec(spec):
    """Return True if *spec* is a ``SpecSparcl`` or ``SpecAIMSZReview`` instance.

    A class is only consulted when its name exists in this module's
    globals, so a class that is not available simply never matches.
    """
    module_globals = globals()
    for class_name in ('SpecSparcl', 'SpecAIMSZReview'):
        if class_name in module_globals and isinstance(spec, module_globals[class_name]):
            return True
    return False
def _is_current_sparcl_like(self):
    """Return whether the current spectrum (``self.spec``, or ``None``) is SPARCL-like."""
    current = getattr(self, "spec", None)
    return self._is_sparcl_like_spec(current)
def _default_class_token(self):
    """Return the default canonical class token, ``"QSO_DEFAULT"``."""
    default_token = "QSO_DEFAULT"
    return default_token
def _reviewer_default(self):
    """Return the session reviewer if set (not ``None``/``""``), else the login name, as ``str``."""
    session_reviewer = getattr(self, "_session_reviewer", None)
    if session_reviewer not in (None, ""):
        return str(session_reviewer or "")
    return str(default_reviewer_username() or "")
@staticmethod
def _is_positive_finite(value):
    """Return a truthy result iff *value* converts to a finite float greater than 0.

    Values ``float()`` rejects count as not positive-finite.
    """
    try:
        number = float(value)
    except Exception:
        return False
    finite = np.isfinite(number)
    return finite and number > 0
def _lookup_external_redshift_value(self, spec):
    """Return the external redshift for *spec*, or ``None`` if there is none.

    The configured key attribute is tried first. For the keys ``object_id``,
    ``objid`` and ``extname``, the other two identifiers are tried next as
    fallbacks. Each attribute value is normalised with
    ``normalize_redshift_lookup_key`` before the lookup.
    """
    lookup = self.external_redshift_lookup
    if not lookup:
        return None
    primary = self.external_redshift_key
    fallbacks = {
        "object_id": ("objid", "extname"),
        "objid": ("object_id", "extname"),
        "extname": ("objid", "object_id"),
    }.get(primary, ())
    for attr in (primary, *fallbacks):
        key = normalize_redshift_lookup_key(getattr(spec, attr, None))
        if key is not None and key in lookup:
            return lookup[key]
    return None
def _apply_external_redshift_to_spec(self, spec):
    """Store the external redshift for *spec* (if any) in ``spec.z_ref``.

    Returns the redshift, or ``None`` (leaving *spec* untouched) when there
    is no match.
    """
    z_ref = self._lookup_external_redshift_value(spec)
    if z_ref is not None:
        spec.z_ref = z_ref
    return z_ref
def _apply_external_redshift_overlay(self, spec, dual=None):
    """Apply external redshifts to a spectrum and, if given, its dual arms.

    Every non-``None`` ``rgs``/``bgs`` arm of *dual* gets its own lookup.
    The first arm redshift found is also copied to ``spec.z_ref`` when that
    is empty (``None``, ``""`` or 0). Without any arm match, *spec* is
    looked up directly.

    Returns the redshift applied, or ``None``.
    """
    arms = () if dual is None else (getattr(dual, "rgs", None), getattr(dual, "bgs", None))
    found = None
    for arm in arms:
        if arm is None:
            continue
        arm_z = self._apply_external_redshift_to_spec(arm)
        if found is None and arm_z is not None:
            found = arm_z
    if found is None:
        return self._apply_external_redshift_to_spec(spec)
    if getattr(spec, "z_ref", None) in (None, "", 0, 0.0):
        spec.z_ref = found
    return found
def _prime_initial_redshift(self, spec):
    """Seed ``spec.z_vi`` from the first positive finite value among ``z_ref``, ``z_temp`` and ``redshift``.

    Leaves ``spec.z_vi`` alone when it is already positive and finite, or
    when no candidate qualifies. The seeded value is cast to ``float``.
    """
    if self._is_positive_finite(getattr(spec, "z_vi", None)):
        return
    candidates = (getattr(spec, name, None) for name in ("z_ref", "z_temp", "redshift"))
    seed = next((value for value in candidates if self._is_positive_finite(value)), None)
    if seed is not None:
        spec.z_vi = float(seed)
def _new_history_record(self, spec, class_vi="", z_vi=None):
    """Build (but do not store) a fresh history record for *spec*.

    Fields default to the spectrum's attributes. *class_vi* is normalised
    with ``normalize_class_label``, and *z_vi* overrides ``spec.z_vi`` when
    given. ``qa_flag`` is cast to int and text fields to str. An empty
    ``reviewer`` falls back to ``_reviewer_default()``.
    """
    def spec_attr(name, default=None):
        return getattr(spec, name, default)

    record = {
        "objname": spec_attr("objname", "Unknown"),
        "ra": spec_attr("ra", np.nan),
        "dec": spec_attr("dec", np.nan),
        "class_vi": normalize_class_label(class_vi),
        "z_vi": spec_attr("z_vi", 0.0) if z_vi is None else z_vi,
        "targetid": spec_attr("targetid"),
        "data_release": normalize_data_release(spec_attr("data_release"), aimsz_review=self._is_aimsz_review),
        "qa_flag": int(spec_attr("qa_flag", 0) or 0),
        "notes": str(spec_attr("notes", "") or ""),
        "reviewer": str(spec_attr("reviewer", "") or self._reviewer_default()),
        "reviewed_at": str(spec_attr("reviewed_at", "") or ""),
        "object_id": spec_attr("object_id"),
    }
    return record
def _coerce_history_record(self, spec, record):
    """Return *record* as a normalised history dict.

    *record* is either a dict (shallow-copied, never mutated) or a legacy
    positional sequence ``(objname, ra, dec, class_vi, z_vi, targetid,
    data_release, qa_flag, notes, reviewer, reviewed_at)``. Missing trailing
    entries of the sequence are taken from *spec*.

    Normalisation:

    - empty fields are filled from *spec*;
    - ``class_vi`` and ``data_release`` are normalised;
    - ``qa_flag`` becomes an int and the text fields become str.

    NA-like cells (``None``, NaN, ``pd.NA``) in ``qa_flag``, ``notes``,
    ``reviewer``, ``reviewed_at`` and ``objname`` count as missing. Such
    cells appear, e.g., in a history table read back through pandas. As a
    result ``int()`` no longer raises, ``"nan"`` is no longer stored as
    text, and a missing reviewer falls back to ``_reviewer_default()``.
    """
    def _missing(value):
        # None or a scalar NA marker; pd.isna is False for ordinary strings.
        if value is None:
            return True
        try:
            return bool(pd.isna(value))
        except (TypeError, ValueError):
            return False

    def _text(value):
        return "" if _missing(value) else str(value or "")

    def _flag(value):
        if _missing(value):
            return 0
        try:
            return int(value or 0)
        except (TypeError, ValueError, OverflowError):
            try:
                # e.g. "2.0" read from a text table
                return int(float(value))
            except (TypeError, ValueError, OverflowError):
                return 0

    if isinstance(record, dict):
        coerced = dict(record)
    else:
        values = list(record)
        coerced = {
            "objname": values[0] if len(values) > 0 else getattr(spec, "objname", "Unknown"),
            "ra": values[1] if len(values) > 1 else getattr(spec, "ra", np.nan),
            "dec": values[2] if len(values) > 2 else getattr(spec, "dec", np.nan),
            "class_vi": values[3] if len(values) > 3 else "",
            "z_vi": values[4] if len(values) > 4 else getattr(spec, "z_vi", 0.0),
            "targetid": values[5] if len(values) > 5 else getattr(spec, "targetid", None),
            "data_release": values[6] if len(values) > 6 else getattr(spec, "data_release", None),
            "qa_flag": values[7] if len(values) > 7 else getattr(spec, "qa_flag", 0),
            "notes": values[8] if len(values) > 8 else "",
            "reviewer": values[9] if len(values) > 9 else "",
            "reviewed_at": values[10] if len(values) > 10 else "",
            "object_id": getattr(spec, "object_id", None),
        }
    coerced["class_vi"] = normalize_class_label(coerced.get("class_vi", ""))
    coerced["qa_flag"] = _flag(coerced.get("qa_flag", 0))
    coerced["notes"] = _text(coerced.get("notes", ""))
    coerced["reviewer"] = _text(coerced.get("reviewer", "")) or str(self._reviewer_default())
    coerced["reviewed_at"] = _text(coerced.get("reviewed_at", ""))
    if coerced.get("targetid", None) is None:
        coerced["targetid"] = getattr(spec, "targetid", None)
    if coerced.get("data_release", None) is None:
        coerced["data_release"] = getattr(spec, "data_release", None)
    coerced["data_release"] = normalize_data_release(coerced.get("data_release", None), aimsz_review=self._is_aimsz_review)
    objname = coerced.get("objname", None)
    if _missing(objname) or objname in ("", "nan"):
        coerced["objname"] = getattr(spec, "objname", "Unknown")
    if coerced.get("ra", None) is None:
        coerced["ra"] = getattr(spec, "ra", np.nan)
    if coerced.get("dec", None) is None:
        coerced["dec"] = getattr(spec, "dec", np.nan)
    if coerced.get("object_id", None) is None:
        coerced["object_id"] = getattr(spec, "object_id", None)
    return coerced
def _history_record(self, spec=None, objid=None, create=False):
    """Look up (and optionally create) the history entry for a spectrum.

    Args:
        spec: spectrum whose ``objid`` is used when ``objid`` is not given;
            also the source for creating or upgrading records.
        objid: explicit history key; takes precedence over ``spec.objid``.
        create: when True and no entry exists, build one with
            ``_new_history_record(spec)`` and store it (requires ``spec``).

    Returns:
        The stored record, or None when no key can be determined or no entry
        exists (and none was created).  A legacy non-dict entry is upgraded
        in place via ``_coerce_history_record`` when ``spec`` is available;
        without ``spec`` it is returned unchanged.
    """
    key = objid
    if key is None and spec is not None:
        key = getattr(spec, "objid", None)
    if key is None:
        return None

    stored = self.history.get(key)
    if stored is None:
        if create and spec is not None:
            stored = self._new_history_record(spec)
            self.history[key] = stored
        return stored

    if isinstance(stored, dict) or spec is None:
        return stored
    upgraded = self._coerce_history_record(spec, stored)
    self.history[key] = upgraded
    return upgraded
def _apply_history_to_spec(self, spec):
    """Copy a previously saved visual inspection back onto ``spec``.

    When a history entry exists, ``z_vi``, ``qa_flag`` (as int) and
    ``class_vi`` are restored; in AIMSZ-review mode ``notes``, ``reviewer``
    and ``reviewed_at`` are restored as well.  Without an entry,
    ``_prime_initial_redshift(spec)`` is called instead.

    Returns:
        The history record, or None when the spectrum has none.
    """
    stored = self._history_record(spec=spec, create=False)
    if stored is not None:
        spec.z_vi = stored.get("z_vi", getattr(spec, "z_vi", 0.0))
        spec.qa_flag = int(stored.get("qa_flag", getattr(spec, "qa_flag", 0)) or 0)
        spec.class_vi = stored.get("class_vi", getattr(spec, "class_vi", ""))
        if self._is_aimsz_review:
            spec.notes = stored.get("notes", "")
            spec.reviewer = stored.get("reviewer", self._reviewer_default())
            spec.reviewed_at = stored.get("reviewed_at", "")
        return stored
    self._prime_initial_redshift(spec)
    return None
def _set_classification(self, spec, class_vi, z_vi=None):
    """Record a visual classification for ``spec`` in ``self.history``.

    Creates the history entry if needed, then stores the normalised class,
    the redshift (``z_vi`` argument, or ``spec.z_vi`` when omitted), the
    integer QA flag and position/identity fields.  Missing
    ``targetid``/``objname`` keys are filled from ``spec``, while ``ra`` and
    ``dec`` always prefer the spectrum's values.  In AIMSZ-review mode the
    reviewer is resolved and written back to ``spec``, and ``spec.class_vi``
    is updated too.

    Returns:
        The updated history record.
    """
    entry = self._history_record(spec=spec, create=True)
    entry["class_vi"] = normalize_class_label(class_vi)
    if z_vi is None:
        entry["z_vi"] = getattr(spec, "z_vi", 0.0)
    else:
        entry["z_vi"] = z_vi
    entry["qa_flag"] = int(getattr(spec, "qa_flag", 0) or 0)
    if self._is_aimsz_review:
        reviewer = getattr(spec, "reviewer", entry.get("reviewer", ""))
        entry["reviewer"] = str(reviewer or self._reviewer_default())
        spec.reviewer = entry["reviewer"]
    entry["targetid"] = entry.get("targetid", getattr(spec, "targetid", None))
    release = entry.get("data_release", getattr(spec, "data_release", None))
    entry["data_release"] = normalize_data_release(release, aimsz_review=self._is_aimsz_review)
    entry["objname"] = entry.get("objname", getattr(spec, "objname", "Unknown"))
    entry["ra"] = getattr(spec, "ra", entry.get("ra", np.nan))
    entry["dec"] = getattr(spec, "dec", entry.get("dec", np.nan))
    if self._is_aimsz_review:
        spec.class_vi = entry["class_vi"]
    self.history[spec.objid] = entry
    return entry
def _history_rows_for_csv(self):
    """Flatten ``self.history`` into one CSV row dict per inspected object.

    Dict records are copied as-is.  Legacy positional records are coerced
    via ``_coerce_history_record``.  The spectrum on screen is used to fill
    their missing fields only when the record actually belongs to it.
    Otherwise the current object's targetid/data_release/qa_flag would be
    written into unrelated rows, so other records get plain defaults.
    Works even before any spectrum has been loaded.

    Returns:
        list[dict]: AIMSZ-review rows carry notes/reviewer/reviewed_at; the
        standard rows carry objname instead.
    """
    current_spec = getattr(self, "spec", None)
    current_objid = getattr(current_spec, "objid", None)
    rows = []
    for objid_key, raw_record in self.history.items():
        if isinstance(raw_record, dict):
            record = dict(raw_record)
        else:
            belongs_to_current = current_objid is not None and objid_key == current_objid
            defaults_spec = current_spec if belongs_to_current else None
            record = self._coerce_history_record(defaults_spec, raw_record)
        class_label = normalize_class_label(record.get("class_vi", ""))
        qa_flag = int(record.get("qa_flag", 0) or 0)
        if self._is_aimsz_review:
            rows.append(
                {
                    "objid": objid_key,
                    "targetid": record.get("targetid"),
                    "ra": record.get("ra"),
                    "dec": record.get("dec"),
                    "data_release": record.get("data_release"),
                    "class_vi": class_label,
                    "z_vi": record.get("z_vi"),
                    "qa_flag": qa_flag,
                    "notes": record.get("notes", ""),
                    "reviewer": record.get("reviewer", "") or self._reviewer_default(),
                    "reviewed_at": record.get("reviewed_at", ""),
                }
            )
        else:
            rows.append(
                {
                    "objid": objid_key,
                    "objname": record.get("objname", "Unknown"),
                    "ra": record.get("ra"),
                    "dec": record.get("dec"),
                    "class_vi": class_label,
                    "z_vi": record.get("z_vi"),
                    "targetid": record.get("targetid"),
                    "data_release": record.get("data_release"),
                    "qa_flag": qa_flag,
                }
            )
    return rows
# Copy all existing methods from original PGSpecPlot
@staticmethod
def _count_hdus_in_file(filename):
    """Return how many spectrum extensions ``filename`` contains.

    ``SpecEuclid1d.count_in_file`` is tried first.  Any failure there,
    including the helper being missing, is ignored on purpose.  The
    fallback opens the file with astropy and counts every HDU except the
    primary one, never returning less than zero.
    """
    try:
        counter = getattr(SpecEuclid1d, "count_in_file", None)
        if counter is not None:
            return int(counter(filename))
    except Exception:
        # Best effort only; fall through to a direct HDU count.
        pass
    with fits.open(filename) as hdul:
        extension_count = len(hdul) - 1
    return max(extension_count, 0)
def _set_legend(self, enabled):
    """Remove any existing legend and, if ``enabled``, add a fresh one.

    Removal is best effort: a legend that is no longer attached to a scene
    is simply dropped.  The new legend is placed with ``offset=(10, 10)``.
    """
    old_legend = self._legend
    if old_legend is not None:
        try:
            old_legend.scene().removeItem(old_legend)
        except Exception:
            # Already detached (scene() may be None); nothing to remove.
            pass
        self._legend = None
    if enabled:
        self._legend = self.addLegend(offset=(10, 10))
def _load_spec(self, index_zero_based):
    """Load a spectrum by 0-based index from ``spectra``.

    Args:
        index_zero_based: 0-based spectrum index (0 = first spectrum).  With a
            file list (``self.speclist``) it indexes the list directly.  With
            a single multi-extension FITS file it is shifted by one, because
            extension 0 is the primary HDU.

    Returns:
        The ``self.SpecClass`` instance, with common defaults ensured.
    """
    if self.speclist is None:
        spec = self.SpecClass(self.specfile, ext=index_zero_based + 1)
    else:
        spec = self.SpecClass(self.speclist[index_zero_based])
    self._ensure_spec_defaults(spec)
    return spec
def _resolve_dual_selector(self, index_zero_based):
    """Translate a 0-based spectrum index into an ``(ext, extname)`` selector.

    FITS mode: a configured ``dual_ext`` wins over ``index + 1``, and an
    empty ``dual_extname`` becomes None.

    Parquet mode: a configured ``dual_extname`` is used verbatim (as str).
    Otherwise the pair index comes from ``dual_ext - 1``, or from the given
    index, and selects an entry of ``self._dual_pair_keys``.

    Raises:
        IndexError: if the parquet pair index is outside the shared pairs.
    """
    if self._dual_parquet_mode:
        if self.dual_extname not in (None, ""):
            return None, str(self.dual_extname)
        if self.dual_ext is None:
            pair_index = int(index_zero_based)
        else:
            pair_index = int(self.dual_ext) - 1
        n_pairs = len(self._dual_pair_keys)
        if not 0 <= pair_index < n_pairs:
            raise IndexError(
                f"Dual-arm parquet pair index {pair_index} out of range for {n_pairs} shared pairs"
            )
        return None, self._dual_pair_keys[pair_index]
    ext = index_zero_based + 1 if self.dual_ext is None else self.dual_ext
    extname = None if self.dual_extname in ("",) else self.dual_extname
    return ext, extname
def _load_dual_spec(self, index_zero_based):
    """Load an RGS/BGS arm pair for the given 0-based index.

    The RGS arm is the primary spectrum when present, otherwise the BGS arm.
    That spectrum gets common defaults and the external-redshift overlay
    (with the dual container passed along).

    Returns:
        tuple: ``(primary_spec, dual)``.

    Raises:
        RuntimeError: if neither arm produced data.
    """
    ext, extname = self._resolve_dual_selector(index_zero_based)
    dual = SpecEuclid1dDual(
        rgs_file=self.dual_rgs_file,
        bgs_file=self.dual_bgs_file,
        ext=ext,
        extname=extname,
        good_pixels_only=self.dual_good_pixels_only,
    )
    primary = dual.rgs
    if primary is None:
        primary = dual.bgs
    if primary is None:
        raise RuntimeError("SpecEuclid1dDual returned no arm data.")
    self._ensure_spec_defaults(primary)
    self._apply_external_redshift_overlay(primary, dual=dual)
    return primary, dual
def _load_current_spec(self, index_zero_based):
    """Load the spectrum at ``index_zero_based`` in the active mode.

    Dual-arm mode stores the loaded pair in ``self.spec_dual``.  Single mode
    resets ``self.spec_dual`` to None, loads the spectrum and applies the
    external-redshift overlay.  Returns the (primary) spectrum object.
    """
    if not self._dual_mode:
        self.spec_dual = None
        single = self._load_spec(index_zero_based)
        self._apply_external_redshift_overlay(single)
        return single
    primary, self.spec_dual = self._load_dual_spec(index_zero_based)
    return primary
def _ensure_spec_defaults(self, spec):
    """Ensure common attributes exist on ``spec``.

    Missing attributes are filled as follows:

    - ``z_vi``/``z_ph`` copy ``redshift`` (or 0.0).
    - ``z_gaia`` becomes None.
    - ``objid`` becomes the current ``self.counter``.
    - ``objname`` becomes 'Unknown'.
    - ``qa_flag`` becomes 0.
    - In AIMSZ-review mode only: ``class_vi``/``notes``/``reviewed_at`` default
      to "", and an absent or empty ``reviewer`` gets ``_reviewer_default()``.

    Existing attributes are left untouched.
    """
    for z_attr in ("z_vi", "z_ph"):
        if not hasattr(spec, z_attr):
            setattr(spec, z_attr, getattr(spec, "redshift", 0.0))
    if not hasattr(spec, "z_gaia"):
        spec.z_gaia = None
    if not hasattr(spec, "objid"):
        spec.objid = self.counter
    if not hasattr(spec, "objname"):
        spec.objname = "Unknown"
    if not hasattr(spec, "qa_flag"):
        spec.qa_flag = 0
    if not self._is_aimsz_review:
        return
    for text_attr in ("class_vi", "notes"):
        if not hasattr(spec, text_attr):
            setattr(spec, text_attr, "")
    # Covers both a missing and an empty reviewer.
    if not getattr(spec, "reviewer", ""):
        spec.reviewer = self._reviewer_default()
    if not hasattr(spec, "reviewed_at"):
        spec.reviewed_at = ""
[docs]
def update_slider_and_spin(self):
    """Synchronise the redshift slider and spin box with ``self.spec.z_vi``.

    Signals are blocked during the update so neither widget triggers a
    replot.  A missing (None), non-numeric, non-finite or non-positive
    ``z_vi`` falls back to ``self.z_min``.

    The slider position inverts the log-spaced mapping used by
    ``slider_changed``: ``z = exp(step * k) * (1 + z_min) - 1``.  It is
    rounded rather than truncated, so a redshift produced by the slider maps
    back to the same slider step despite floating-point error.
    """
    raw_z = getattr(self.spec, "z_vi", None)
    try:
        z_value = float(raw_z)
    except (TypeError, ValueError):
        z_value = float("nan")
    initial_z = z_value if np.isfinite(z_value) and z_value > 0 else self.z_min
    initial_slider_value = int(round((1 / self.base_z_step) * np.log((1 + initial_z) / (1 + self.z_min))))
    self.slider.blockSignals(True)
    self.slider.setValue(initial_slider_value)
    self.slider.blockSignals(False)
    self.redshiftSpin.blockSignals(True)
    self.redshiftSpin.setValue(initial_z)
    self.redshiftSpin.blockSignals(False)
def _on_manual_range_changed(self, *_args):
    """Range-change slot: remember the new view as the locked view.

    Signal arguments are ignored.  Nothing happens while programmatic range
    updates are suspended (``_suspend_view_lock_updates``).
    """
    if not self._suspend_view_lock_updates:
        self._capture_current_view_range()
def _capture_current_view_range(self):
    """Store the view box's current x/y range as the locked view.

    The range is ignored (and the lock left unchanged) when it cannot be
    read, is not a pair of 2-element ranges, or contains non-finite values.
    On success ``_locked_view_range`` becomes ``((x0, x1), (y0, y1))`` as
    floats and the lock is activated.
    """
    try:
        x_range, y_range = self.vb.viewRange()
    except Exception:
        return
    if x_range is None or y_range is None or len(x_range) != 2 or len(y_range) != 2:
        return
    bounds = [x_range[0], x_range[1], y_range[0], y_range[1]]
    if not all(np.isfinite(bounds)):
        return
    x0, x1, y0, y1 = (float(edge) for edge in bounds)
    self._locked_view_range = ((x0, x1), (y0, y1))
    self._view_lock_active = True
def _clear_view_lock(self):
    """Drop any locked view so the next ``_apply_axes`` re-ranges the plot."""
    self._locked_view_range = None
    self._view_lock_active = False
def _restore_locked_view(self):
    """Re-apply the locked view range, if any.

    Auto-ranging is disabled on both axes and the stored range is set with
    zero padding.  Range-change capture is suspended meanwhile, so this
    programmatic update is not recorded as a new manual lock.

    Returns:
        bool: True if the view was restored.  False when no lock is active,
        the stored range is non-finite, or applying it fails.
    """
    if not (self._view_lock_active and self._locked_view_range is not None):
        return False
    try:
        x_range, y_range = self._locked_view_range
        corners = (x_range[0], x_range[1], y_range[0], y_range[1])
        if not all(np.isfinite(corners)):
            return False
        self._suspend_view_lock_updates = True
        for axis_name in ('x', 'y'):
            self.enableAutoRange(axis=axis_name, enable=False)
        self.vb.setRange(
            xRange=(float(x_range[0]), float(x_range[1])),
            yRange=(float(y_range[0]), float(y_range[1])),
            padding=0.0,
            disableAutoRange=True,
        )
        return True
    except Exception:
        return False
    finally:
        self._suspend_view_lock_updates = False
[docs]
def slider_changed(self, slider_value):
    """Handle a slider move: set ``spec.z_vi`` from the log-spaced slider step.

    The redshift is ``exp(base_z_step * slider_value) * (1 + z_min) - 1``.
    The spin box is updated with its signals blocked, and the plot is
    redrawn, keeping the locked view when one is active.
    """
    new_z = np.exp(self.base_z_step * slider_value) * (1 + self.z_min) - 1
    self.spec.z_vi = new_z
    spin_box = self.redshiftSpin
    spin_box.blockSignals(True)
    spin_box.setValue(new_z)
    spin_box.blockSignals(False)
    self.clear()
    self.plot_single(preserve_view=self._view_lock_active)
[docs]
def spin_changed(self, z_value):
    """Handle a redshift typed into the spin box.

    Sets ``spec.z_vi`` and moves the slider, with its signals blocked, to
    the nearest log-spaced step.  This inverts
    ``z = exp(step * k) * (1 + z_min) - 1``.  Rounding to the nearest step,
    instead of truncating, keeps the slider consistent with
    ``slider_changed`` despite floating-point error.  The plot is then
    redrawn, keeping the locked view when one is active.
    """
    self.spec.z_vi = z_value
    slider_value = int(round((1 / self.base_z_step) * np.log((1 + z_value) / (1 + self.z_min))))
    self.slider.blockSignals(True)
    self.slider.setValue(slider_value)
    self.slider.blockSignals(False)
    self.clear()
    self.plot_single(preserve_view=self._view_lock_active)
[docs]
def plot_single(self, preserve_view=False):
    """Plot the current spectrum and its template overlay.

    Layouts:

    * Dual-arm Euclid (``self.spec_dual`` set): RGS and scaled BGS arms,
      each with a good-pixel highlight, plus a legend.
    * SPARCL-like spectra: restricted to 3800-9800 Å, sigma-clipped
      (10 sigma, 3 iterations) and optionally downsampled.  An optional
      median-scaled Euclid overlay comes from ``self.euclid_fits``.
    * Single Euclid spectra: finite pixels, with a good-pixel highlight.
    * Any other spectrum: a sigma-clipped scatter/line plot.

    Side effects:

    * Sets ``self.wave``/``self.flux``, used for template scaling and label
      placement.
    * Sets ``self._observed_wmin``/``_observed_wmax``, the x-range and
      template clip window.  Both stay None when no finite wavelength
      exists, in which case no template is drawn.
    * Sets ``self._annotation_wave``/``_flux``, the points used for click
      annotations.
    * Refreshes the axis labels (a plain 'Flux' when the spectrum has no
      ``flux_unit``), the info label and the axes.

    Args:
        preserve_view: restore the user's locked view instead of re-ranging.
    """
    spec = self.spec
    dual = self.spec_dual
    is_sparcl = self._is_sparcl_like_spec(spec)
    # ``telescope`` may be absent or None; both mean "not Euclid".
    is_euclid = str(getattr(spec, 'telescope', '') or '').lower() == 'euclid'
    is_euclid_dual = dual is not None

    def _plain(values):
        # Strip astropy units (Quantity.value) when present.
        return values.value if hasattr(values, 'value') else values

    def _finite_span(*arrays):
        # (min, max) over the finite entries of ``arrays``; None if there are none.
        pieces = [np.asarray(a, dtype=float).ravel() for a in arrays if a is not None]
        merged = np.concatenate(pieces) if pieces else np.empty(0)
        merged = merged[np.isfinite(merged)]
        if merged.size == 0:
            return None
        return float(merged.min()), float(merged.max())

    self._observed_wmin = None
    self._observed_wmax = None
    self._annotation_wave = None
    self._annotation_flux = None
    self.flux_sm = None
    self._set_legend(is_euclid_dual)
    annotation_wave_parts = []
    annotation_flux_parts = []
    observed_span = None

    wave_full = _plain(spec.wave)
    flux_full = _plain(spec.flux)

    if is_euclid_dual:
        dual_data = dual.for_redshift()

        def _arm_arrays(arm_key):
            # Raw wavelength grid, plus the jointly finite (wave, flux) samples.
            arm_data = dual_data.get(arm_key) or {}
            raw_wave = arm_data.get("wavelength")
            raw_flux = arm_data.get("flux")
            arm_wave = np.asarray(raw_wave if raw_wave is not None else [], dtype=float)
            arm_flux = np.asarray(raw_flux if raw_flux is not None else [], dtype=float)
            keep = np.isfinite(arm_wave) & np.isfinite(arm_flux)
            return arm_wave, arm_wave[keep], arm_flux[keep]

        def _plot_good_pixels(arm, rgb, scale):
            # Overdraw pixels flagged by the arm's own good_mask (if it matches its grid).
            if arm is None:
                return
            good_mask = getattr(arm, "good_mask", None)
            if good_mask is None or len(good_mask) != len(getattr(arm, "wave", [])):
                return
            arm_wave = np.asarray(_plain(arm.wave), dtype=float)
            arm_flux = np.asarray(_plain(arm.flux), dtype=float)
            good = np.asarray(good_mask, dtype=bool) & np.isfinite(arm_wave) & np.isfinite(arm_flux)
            if np.any(good):
                self.plot(arm_wave[good], arm_flux[good] * scale, pen=pg.mkPen(rgb + (240,), width=2), antialias=True)

        raw_wave_rgs, wave_rgs, flux_rgs = _arm_arrays("rgs")
        raw_wave_bgs, wave_bgs, flux_bgs = _arm_arrays("bgs")
        if wave_rgs.size > 0:
            self.plot(wave_rgs, flux_rgs, pen=pg.mkPen((44, 160, 44, 220), width=2), name="RGS", antialias=True)
            _plot_good_pixels(dual.rgs, (44, 160, 44), 1.0)
        if wave_bgs.size > 0:
            self.plot(wave_bgs, flux_bgs, pen=pg.mkPen((31, 119, 180, 220), width=2), name="BGS (scaled)", antialias=True)
            # The arm's own flux is unscaled; bring its good pixels onto the RGS level.
            bgs_scale = getattr(dual, "arm_scale_bgs_to_rgs", 1.0)
            _plot_good_pixels(dual.bgs, (31, 119, 180), 1.0 if bgs_scale is None else float(bgs_scale))

        all_wave = np.concatenate([wave_rgs, wave_bgs])
        all_flux = np.concatenate([flux_rgs, flux_bgs])
        # Always refresh, so a previous spectrum's flux is never reused for scaling.
        self.wave = all_wave
        self.flux = all_flux
        if all_wave.size > 0:
            annotation_wave_parts.append(all_wave)
            annotation_flux_parts.append(all_flux)
        # The x-range covers every finite wavelength, even where flux is missing.
        observed_span = _finite_span(raw_wave_rgs, raw_wave_bgs)
    else:
        if is_sparcl:
            window = (wave_full >= 3800) & (wave_full <= 9800)
            wave_full = wave_full[window]
            flux_full = flux_full[window]
        if is_euclid:
            finite = np.isfinite(wave_full) & np.isfinite(flux_full)
            wave = wave_full[finite]
            flux = flux_full[finite]
        else:
            # Mask non-finite flux, then drop 10-sigma outliers.
            flux_sigclip = sigma_clip(np.ma.masked_invalid(flux_full), sigma=10, maxiters=3)
            keep = ~flux_sigclip.mask
            wave = wave_full[keep]
            flux = flux_sigclip.data[keep]
        self.wave = wave
        self.flux = flux
        annotation_wave_parts.append(np.asarray(wave))
        annotation_flux_parts.append(np.asarray(flux))
        observed_span = _finite_span(wave)

        if is_sparcl:
            if self.downsampling_enabled:
                curve = self.plot(wave, flux, pen=pg.mkPen('k', width=2), antialias=True)
                curve.setDownsampling(ds=3, auto=self._downsampling_auto, method=self._downsampling_method)
                curve.setClipToView(self._downsampling_clip_to_view)
            else:
                self.plot(wave, flux, pen=pg.mkPen('k', width=1.5), antialias=True)

            euclid_object_id = getattr(spec, "euclid_object_id", None)
            euclid_spec = None
            if self.euclid_fits is not None and euclid_object_id not in (None, "", 0):
                euclid_spec = self._load_euclid_overlay(euclid_object_id)
            if euclid_spec is not None:
                dr_text = str(getattr(spec, "data_release", "") or "")
                # DESI spectra get a purple good-pixel overlay, others green.
                if "desi" in dr_text.lower():
                    overlay_good_pen = pg.mkPen((128, 0, 128, 220), width=2)
                else:
                    overlay_good_pen = pg.mkPen((0, 150, 0, 220), width=2)
                euclid_wave = euclid_spec.wave.value
                euclid_flux = euclid_spec.flux.value
                euclid_good_mask = getattr(euclid_spec, 'good_mask', None)
                mask_matches = euclid_good_mask is not None and len(euclid_good_mask) == len(euclid_flux)

                # Scale the overlay so its median |flux| matches the SPARCL
                # spectrum's (using good Euclid pixels when available).
                denom_flux = euclid_flux
                if mask_matches:
                    good = np.asarray(euclid_good_mask, dtype=bool) & np.isfinite(euclid_wave) & np.isfinite(euclid_flux)
                    if np.any(good):
                        denom_flux = euclid_flux[good]
                scale = 1.0
                denom = np.nanmedian(np.abs(denom_flux))
                numer = np.nanmedian(np.abs(flux))
                if np.isfinite(denom) and denom > 0 and np.isfinite(numer) and numer > 0:
                    scale = numer / denom
                euclid_flux_scaled = euclid_flux * scale

                self.plot(euclid_wave, euclid_flux_scaled, pen=pg.mkPen((95, 95, 95, 150), width=2), antialias=True)
                if mask_matches:
                    good = np.asarray(euclid_good_mask, dtype=bool) & np.isfinite(euclid_wave) & np.isfinite(euclid_flux_scaled)
                    if np.any(good):
                        self.plot(euclid_wave[good], euclid_flux_scaled[good], pen=overlay_good_pen, antialias=True)
                annotation_wave_parts.append(np.asarray(euclid_wave))
                annotation_flux_parts.append(np.asarray(euclid_flux_scaled))
                try:
                    observed_span = _finite_span(wave, euclid_wave)
                except (TypeError, ValueError):
                    pass  # keep the SPARCL-only span
        elif is_euclid:
            self.plot(wave, flux, pen=pg.mkPen((95, 95, 95, 150), width=2), antialias=True)
            good_mask = getattr(spec, 'good_mask', None)
            if good_mask is not None and len(good_mask) == len(wave_full):
                good_mask = np.asarray(good_mask, dtype=bool) & np.isfinite(wave_full) & np.isfinite(flux_full)
                if np.any(good_mask):
                    wave_good = wave_full[good_mask]
                    flux_good = flux_full[good_mask]
                    self.plot(wave_good, flux_good, pen=pg.mkPen((0, 0, 180, 220), width=2), antialias=True)
                    # Template scaling and label placement use only the good pixels.
                    self.wave = wave_good
                    self.flux = flux_good
        else:
            self.plot(wave, flux, pen='b', symbol='o', symbolSize=4,
                      symbolPen=None, connect='finite', symbolBrush='k', antialias=True)

    if observed_span is not None:
        self._observed_wmin, self._observed_wmax = observed_span
    if annotation_wave_parts:
        self._annotation_wave = np.concatenate(annotation_wave_parts)
        self._annotation_flux = np.concatenate(annotation_flux_parts)

    # Axis labels, with units when the spectrum provides them.
    flux_unit = getattr(spec, 'flux_unit', None)
    flux_unit_str = f'Flux ({flux_unit})' if flux_unit is not None else 'Flux'
    wave_unit = getattr(spec, 'wave_unit', None)
    wave_unit_str = f'Wavelength ({wave_unit})' if wave_unit is not None else 'Wavelength (Å)'

    # Plot the template, always clipped to the observed wavelength span.
    template = self.template_manager.get_template(self.template_manager.current_template)
    if template is not None and observed_span is not None:
        wmin, wmax = observed_span
        z_vi = getattr(spec, 'z_vi', getattr(spec, 'redshift', 0.0))
        if z_vi is None:
            z_vi = 0.0
        wave_temp = template['wave'] * (1 + z_vi)
        flux_temp = template['flux']
        in_range = (wave_temp >= wmin) & (wave_temp <= wmax)
        wave_temp = wave_temp[in_range]
        flux_temp = flux_temp[in_range]
        template_mean = np.mean(flux_temp) if wave_temp.size > 0 else np.nan
        if np.isfinite(template_mean) and template_mean != 0:
            cleaned_flux = getattr(self, "flux", None)
            ref_flux = cleaned_flux if cleaned_flux is not None and len(cleaned_flux) > 0 else flux_temp
            flux_temp_scaled = flux_temp / template_mean * np.abs(np.nanmean(ref_flux)) * 1.5
            self.plot(wave_temp, flux_temp_scaled, pen=pg.mkPen(_TEMPLATE_COLOR, width=2), antialias=True)
            self._label_template_emission_lines(wmin=wmin, wmax=wmax, z=z_vi)

    self.setLabel('left', flux_unit_str)
    self.setLabel('bottom', wave_unit_str)
    self.update_spectrum_info_label()
    self._apply_axes(preserve_view=preserve_view, is_sparcl=is_sparcl)
# Coordinate signal emission is now handled in navigation methods
def _apply_axes(self, *, preserve_view=False, is_sparcl=False):
    """Set the plot's x/y ranges after a redraw.

    A locked user view is restored when ``preserve_view`` is set.
    Otherwise:

    - x spans ``[_observed_wmin, _observed_wmax]`` exactly, when both are known.
    - y is fixed to ±1e-16 (spectra with a ``flux_unit``) or ±1 when all
      finite flux values are exactly zero.
    - For SPARCL spectra, y spans the 0.01-99.99 percentile range with a 5%
      margin, widened if degenerate.
    - Otherwise y auto-ranges.

    Range-change capture is suspended throughout, so these programmatic
    changes never become a user view lock.
    """
    if preserve_view and self._restore_locked_view():
        return
    self._suspend_view_lock_updates = True
    try:
        lo, hi = self._observed_wmin, self._observed_wmax
        if lo is not None and hi is not None:
            self.enableAutoRange(axis='x', enable=False)
            self.setXRange(float(lo), float(hi), padding=0.0)

        flux_source = self._annotation_flux
        if flux_source is None:
            flux_source = getattr(self, 'flux', None)
        zero_limits = None
        if flux_source is not None:
            values = np.asarray(flux_source, dtype=float)
            values = values[np.isfinite(values)]
            if values.size > 0 and np.allclose(values, 0.0, rtol=0.0, atol=0.0):
                half_height = 1e-16 if getattr(self.spec, 'flux_unit', None) is not None else 1.0
                zero_limits = (-half_height, half_height)
        if zero_limits is not None:
            self.enableAutoRange(axis='y', enable=False)
            self.setYRange(float(zero_limits[0]), float(zero_limits[1]), padding=0.0)
            return

        if not is_sparcl:
            self.enableAutoRange(axis='y', enable=True)
            return
        reference = np.asarray(getattr(self, 'flux', []), dtype=float)
        reference = reference[np.isfinite(reference)]
        if reference.size == 0:
            self.enableAutoRange(axis='y', enable=True)
            return
        self.enableAutoRange(axis='y', enable=False)
        y1, y2 = np.percentile(reference, [0.01, 99.99])
        spread = y2 - y1
        ymin = y1 - 0.05 * spread
        ymax = y2 + 0.05 * spread
        if ymin == ymax:
            pad = abs(ymin) * 0.1 if ymin != 0 else 1.0
            ymin -= pad
            ymax += pad
        self.setYRange(ymin, ymax, padding=0.05)
    finally:
        self._suspend_view_lock_updates = False
[docs]
def update_spectrum_info_label(self):
    """Refresh the summary text shown above the plot.

    The text contains, in order:

    - the position in the list ("Spectrum i/N");
    - the ID (``object_id`` in AIMSZ-review mode);
    - the targetid for SPARCL/AIMSZ spectra;
    - ``z_vi``;
    - the visual class (history entry first, then the spectrum attribute);
    - ``z_ref`` (non-AIMSZ mode only);
    - the DESI redshift for DESI releases;
    - a non-zero ``z_gaia``.

    In dual-arm mode a second line reports the BGS→RGS scaling diagnostics.
    Does nothing before a spectrum has been loaded.
    """
    if not hasattr(self, 'spec'):
        return
    spec = self.spec
    z_vi = getattr(spec, 'z_vi', 0.0)
    z_gaia = getattr(spec, 'z_gaia', None)
    objid = getattr(spec, 'objid', 'Unknown')

    stored = self._history_record(spec=spec, create=False)
    class_vi = None if stored is None else stored.get("class_vi")
    if class_vi in (None, ""):
        class_vi = getattr(spec, "class_vi", None)

    def _fmt_z(label, value, *, hide_zero=True):
        # "label = x.xxxx", or None for missing/non-numeric/non-finite (optionally zero) values.
        if value is None:
            return None
        try:
            number = float(value)
        except Exception:
            return None
        if not np.isfinite(number) or (hide_zero and number == 0.0):
            return None
        return f"{label} = {number:.4f}"

    has_list = hasattr(self, 'len_list')
    # The navigation methods set ``_displaying_spectrum_number`` before plotting;
    # the counter is only a fallback (e.g. for template refreshes).
    if hasattr(self, '_displaying_spectrum_number'):
        shown_number = self._displaying_spectrum_number
    elif has_list:
        shown_number = self.counter
    else:
        shown_number = 1
    parts = [f"Spectrum {shown_number}/{self.len_list}"] if has_list else []

    display_objid = getattr(spec, "object_id", objid) if self._is_aimsz_review else objid
    parts.append(f"ID: {display_objid}")
    catalog_spec = (
        ('SpecSparcl' in globals() and isinstance(spec, SpecSparcl))
        or ('SpecAIMSZReview' in globals() and isinstance(spec, SpecAIMSZReview))
    )
    if catalog_spec:
        targetid = getattr(spec, 'targetid', None)
        if targetid not in (None, "", 0):
            parts.append(f"targetid: {targetid}")
    parts.append(_fmt_z("z_vi", z_vi, hide_zero=False) or "z_vi = -")
    if class_vi not in (None, ""):
        parts.append(f"class_vi: {display_class_label(class_vi)}")
    if not self._is_aimsz_review:
        z_ref_text = _fmt_z("z_ref", getattr(spec, "z_ref", None), hide_zero=False)
        if z_ref_text is not None:
            parts.append(z_ref_text)
    if catalog_spec and 'desi' in str(getattr(spec, 'data_release', '') or '').lower():
        parts.append(_fmt_z("z_desi", getattr(spec, 'redshift', None), hide_zero=False) or "z_desi = -")
    z_gaia_text = _fmt_z("z_gaia", z_gaia, hide_zero=True)
    if z_gaia_text is not None:
        parts.append(z_gaia_text)

    dual_diag = None
    dual = self.spec_dual
    if dual is not None:
        scale = float(getattr(dual, "arm_scale_bgs_to_rgs", 1.0))
        status = str(getattr(dual, "scale_status", "unknown"))
        owmin = getattr(dual, "overlap_wmin", np.nan)
        owmax = getattr(dual, "overlap_wmax", np.nan)
        n_bgs = int(getattr(dual, "overlap_n_bgs", 0))
        n_rgs = int(getattr(dual, "overlap_n_rgs", 0))
        prefix = f"BGS→RGS scale={scale:.4g} ({status}), "
        if np.isfinite(owmin) and np.isfinite(owmax) and owmax > owmin:
            dual_diag = prefix + f"overlap={owmin:.1f}-{owmax:.1f} Å, nB={n_bgs}, nR={n_rgs}"
        else:
            dual_diag = prefix + "overlap=none"

    text = " ".join(parts)
    if dual_diag:
        text = f"{text}\n{dual_diag}"
    if hasattr(self, 'spectrum_info_label'):
        self.spectrum_info_label.setText(text)
[docs]
def plot_template(self):
    """Overlay the current template at the spectrum's visual redshift.

    The template is redshifted by ``z_vi``, falling back to ``redshift``,
    and to 0 when that is None.  It is clipped to the observed wavelength
    span and scaled exactly as in ``plot_single``: normalised to unit mean,
    then multiplied by 1.5 × |nanmean| of the cleaned flux (``self.flux``).
    The raw spectrum flux is used when no cleaned flux is available.
    Emission-line markers are added and the info label refreshed.  The axes
    are deliberately left untouched, so the current x-range is preserved.
    Nothing is drawn when there is no template, no usable wavelength, or the
    clipped template is empty or has a zero/non-finite mean.
    """
    template = self.template_manager.get_template(self.template_manager.current_template)
    if template is None:
        return
    z = getattr(self.spec, 'z_vi', getattr(self.spec, 'redshift', 0.0))
    if z is None:
        z = 0.0
    wave_shifted = template['wave'] * (1 + z)
    flux_template = template['flux']
    if self._observed_wmin is not None and self._observed_wmax is not None:
        wmin = float(self._observed_wmin)
        wmax = float(self._observed_wmax)
    else:
        wave_full = self.spec.wave.value if hasattr(self.spec.wave, 'value') else self.spec.wave
        flux_full = self.spec.flux.value if hasattr(self.spec.flux, 'value') else self.spec.flux
        finite = np.isfinite(wave_full) & np.isfinite(flux_full)
        if not np.any(finite):
            # No usable pixel (or an empty spectrum): nothing to clip against.
            return
        wmin = float(np.nanmin(wave_full[finite]))
        wmax = float(np.nanmax(wave_full[finite]))
    in_range = (wave_shifted >= wmin) & (wave_shifted <= wmax)
    wave_shifted = wave_shifted[in_range]
    flux_template = flux_template[in_range]
    if wave_shifted.size == 0:
        return
    template_mean = np.mean(flux_template)
    if not np.isfinite(template_mean) or template_mean == 0:
        return
    # Prefer the cleaned flux from plot_single; fall back to the raw spectrum.
    cleaned_flux = getattr(self, 'flux', None)
    if cleaned_flux is not None and len(cleaned_flux) > 0:
        reference_flux = cleaned_flux
    else:
        reference_flux = self.spec.flux.value if hasattr(self.spec.flux, 'value') else self.spec.flux
    flux_scaled = flux_template / template_mean * np.abs(np.nanmean(reference_flux)) * 1.5
    self.plot(wave_shifted, flux_scaled, pen=pg.mkPen(_TEMPLATE_COLOR, width=2), antialias=True)
    if self._observed_wmin is not None and self._observed_wmax is not None:
        self._label_template_emission_lines(wmin=self._observed_wmin, wmax=self._observed_wmax, z=z)
    elif hasattr(self, 'wave') and self.wave is not None and len(self.wave) > 0:
        self._label_template_emission_lines(wmin=float(np.nanmin(self.wave)), wmax=float(np.nanmax(self.wave)), z=z)
    # Update the info label to reflect the new redshift.
    self.update_spectrum_info_label()
    # Do NOT auto-range during template updates - preserves x-axis range
def _load_euclid_overlay(self, euclid_object_id):
    """Return the Euclid 1D spectrum used as an overlay, memoised per ID.

    The ID is normalised to a stripped string extension name.  Integer-
    valued ints and floats lose any ".0", so 123, 123.0 and np.int64(123)
    share one cache entry.  None, NaN, empty and "nan" IDs yield None.

    A failed load is printed and cached as None, so it is not retried.
    Results live in ``self._euclid_overlay_cache``.
    """
    if euclid_object_id is None:
        return None
    try:
        if isinstance(euclid_object_id, float) and np.isnan(euclid_object_id):
            return None
    except Exception:
        pass
    key = euclid_object_id
    is_integral_float = isinstance(key, (np.floating, float)) and float(key).is_integer()
    if isinstance(key, (np.integer, int)) or is_integral_float:
        key = int(key)
    key = str(key).strip()
    if not key or key.lower() == "nan":
        return None
    cache = self._euclid_overlay_cache
    if key not in cache:
        try:
            cache[key] = SpecEuclid1d(self.euclid_fits, extname=key)
        except Exception as e:
            print(f"Euclid overlay load failed for extname={key}: {e}")
            cache[key] = None
    return cache[key]
def _label_template_emission_lines(self, *, wmin, wmax, z):
    """Overlay emission-line markers for the template within [wmin, wmax].

    Each line in ``_TEMPLATE_EMISSION_LINES`` is redshifted by ``z``.  When
    it falls inside the window, it gets a dashed red vertical line and a
    grey name label.  Labels sit just below the maximum of the cleaned flux
    (``self.flux``), or at 0 when that is unavailable, and are staggered
    down in three steps to reduce overlap.  Returns early for an invalid
    window or a non-numeric, non-finite or negative ``z``.
    """
    if not (np.isfinite(wmin) and np.isfinite(wmax) and wmax > wmin):
        return
    try:
        z = float(z)
    except Exception:
        return
    if not np.isfinite(z) or z < 0:
        return

    y_ref = None
    cleaned_flux = getattr(self, 'flux', None)
    if cleaned_flux is not None and len(cleaned_flux) > 0:
        try:
            y_ref = float(np.nanmax(cleaned_flux))
        except Exception:
            y_ref = None
    if y_ref is None or not np.isfinite(y_ref):
        y_ref = 0.0
    y_base = y_ref * 0.92 if y_ref != 0 else 0.0

    line_pen = pg.mkPen((255, 0, 0), width=2, style=Qt.DashLine)
    label_color = (80, 80, 80)
    stagger = (0.0, 0.06, 0.12)
    n_drawn = 0
    for name, rest_aa in _TEMPLATE_EMISSION_LINES:
        x_obs = rest_aa * (1.0 + z)
        if x_obs < wmin or x_obs > wmax:
            continue
        self.addItem(pg.InfiniteLine(pos=x_obs, angle=90, pen=line_pen, movable=False))
        y_pos = y_base * (1.0 - stagger[n_drawn % len(stagger)]) if y_base != 0 else 0.0
        text_item = pg.TextItem(text=str(name), color=label_color, anchor=(0.5, 1.0))
        text_item.setPos(x_obs, y_pos)
        self.addItem(text_item)
        n_drawn += 1
[docs]
def plot_next(self):
    """Plot the next spectrum, skipping any that fail to load.

    ``self.counter`` is the 0-based index of the next spectrum to show.
    Unloadable spectra are reported and skipped.  On success the history is
    re-applied, the redshift widgets are synced and the spectrum is plotted.
    Then the coordinate/current-spectrum signals are emitted and the
    counter advances past the shown spectrum.
    """
    if self.counter >= self.len_list:
        print("No more spectra to plot.")
        return
    self.clear()
    self._clear_view_lock()
    for position in range(self.counter, self.len_list):
        try:
            spec = self._load_current_spec(position)
            self.spec = spec
            break
        except Exception as e:
            print(f"Failed to load spectrum {position + 1}: {e}")
            self.counter = position + 1
    else:
        print("No more spectra to plot.")
        return
    record = self._apply_history_to_spec(spec)
    if record is not None:
        print(f"\tVisual class from history: {record.get('class_vi', '')}.")
    self.update_slider_and_spin()
    # 1-based number of the spectrum being shown (counter not yet advanced).
    self._displaying_spectrum_number = self.counter + 1
    self.plot_single()
    # Cutout panels listen for the new coordinates.
    if hasattr(self.spec, 'ra') and hasattr(self.spec, 'dec'):
        self.coordinate_changed.emit(self.spec.ra, self.spec.dec)
    self.current_spec_changed.emit()
    self.counter += 1
[docs]
def plot_previous(self):
    """Plot the spectrum before the one currently shown.

    The currently shown spectrum has 0-based index ``counter - 1``, so the
    search starts at ``counter - 2`` and walks backwards past any spectra
    that fail to load.  On success the counter is set to the 1-based number
    of the shown spectrum, history and redshift widgets are updated, the
    spectrum is plotted, and the coordinate/current-spectrum signals are
    emitted.
    """
    if self.counter <= 1:
        print("No previous spectrum to plot.")
        return
    self.clear()
    self._clear_view_lock()
    for target in range(self.counter - 2, -1, -1):
        try:
            spec = self._load_current_spec(target)
            self.spec = spec
            break
        except Exception as e:
            print(f"Failed to load spectrum {target + 1}: {e}")
    else:
        print("No previous spectrum to plot.")
        return
    record = self._apply_history_to_spec(spec)
    if record is not None:
        print(f"\tVisual class from history: {record.get('class_vi', '')}.")
    self.counter = target + 1
    self.update_slider_and_spin()
    # After the update the counter equals the 1-based number being shown.
    self._displaying_spectrum_number = self.counter
    self.plot_single()
    # Cutout panels listen for the new coordinates.
    if hasattr(self.spec, 'ra') and hasattr(self.spec, 'dec'):
        self.coordinate_changed.emit(self.spec.ra, self.spec.dec)
    self.current_spec_changed.emit()
[docs]
def jump_to_spectrum(self, index_one_based):
    """Jump directly to a spectrum by 1-based index.

    Non-integer or out-of-range indices are reported and ignored.  The plot
    is cleared before loading, as in the other navigation methods.  If the
    load then fails, the previously displayed spectrum (still
    ``self.spec``) is redrawn, so the viewer is not left blank while its
    state still refers to that spectrum.  On success ``self.counter`` and
    the display number become ``index_one_based``, history and redshift
    widgets are updated, the spectrum is plotted and the
    coordinate/current-spectrum signals are emitted.
    """
    try:
        index_one_based = int(index_one_based)
    except Exception:
        print(f"Invalid index: {index_one_based}")
        return
    if index_one_based < 1 or index_one_based > self.len_list:
        print(f"Index out of range: {index_one_based}. Valid range is 1..{self.len_list}.")
        return
    self.clear()
    self._clear_view_lock()
    target_zero_based = index_one_based - 1
    try:
        spec = self._load_current_spec(target_zero_based)
    except Exception as e:
        print(f"Failed to load spectrum {index_one_based}: {e}")
        # Restore what was on screen instead of leaving an empty plot.
        if getattr(self, "spec", None) is not None:
            try:
                self.plot_single()
            except Exception as redraw_error:
                print(f"Failed to redraw the previous spectrum: {redraw_error}")
        return
    self.spec = spec
    record = self._apply_history_to_spec(spec)
    if record is not None:
        class_vi = record.get("class_vi", "")
        print(f"\tVisual class from history: {class_vi}.")
    self.counter = index_one_based
    self.update_slider_and_spin()
    self._displaying_spectrum_number = index_one_based
    self.plot_single()
    if hasattr(self.spec, 'ra') and hasattr(self.spec, 'dec'):
        self.coordinate_changed.emit(self.spec.ra, self.spec.dec)
    self.current_spec_changed.emit()
[docs]
def change_template(self, template_name):
    """Switch the overlay template and redraw the current spectrum.

    Names not offered by the template manager are ignored.  Like the
    redshift slider and spin box, the redraw keeps the user's locked
    zoom/pan when one is active, instead of resetting the axes.
    """
    if template_name not in self.template_manager.get_available_templates():
        return
    self.template_manager.current_template = template_name
    self.clear()
    self.plot_single(preserve_view=getattr(self, "_view_lock_active", False))
def _annotate_at_wave(self, wave_pos):
    """Annotate the nearest plotted point to ``wave_pos``.

    Candidates come from the annotation arrays set by ``plot_single``, with
    a fallback to ``self.wave``/``self.flux``.  Only points where both wave
    and flux are finite are considered.  The chosen point gets a red,
    bold 18 pt text label with its wavelength and flux, and the same text
    is printed.  Does nothing when no finite candidate exists.
    """
    waves = self._annotation_wave
    if waves is None:
        waves = getattr(self, 'wave', None)
    fluxes = self._annotation_flux
    if fluxes is None:
        fluxes = getattr(self, 'flux', None)
    if waves is None or fluxes is None:
        return
    usable = np.isfinite(waves) & np.isfinite(fluxes)
    if not np.any(usable):
        return
    candidate_waves = waves[usable]
    candidate_fluxes = fluxes[usable]
    nearest = np.abs(candidate_waves - wave_pos).argmin()
    wave_val = candidate_waves[nearest]
    flux_val = candidate_fluxes[nearest]
    caption = "Wavelength: {0:.2f} Flux: {1:.2e}".format(wave_val, flux_val)
    marker = pg.TextItem(
        text=caption,
        anchor=(0, 0), color='r', border='w',
        fill=(255, 255, 255, 200))
    marker.setFont(QFont("Arial", 18, QFont.Bold))
    marker.setPos(wave_val, flux_val)
    self.addItem(marker)
    print(caption)
[docs]
def keyPressEvent(self, event):
    """Handle keyboard shortcuts for classification, navigation and tools.

    Without Ctrl:
      Q                  commit the current record (default class when the
                         object has no history yet) and advance; whenever
                         ``(counter - 1) % 50 == 0`` a ``vi_temp_<n>.csv``
                         checkpoint is written to the working directory
      S G A N B F U D L  classify as STAR, GALAXY, QSO, QSO_NARROW, QSO_BAL,
                         QSO_FELOBAL, UNKNOWN, BAD, LIKELY_Q (STAR, UNKNOWN
                         and BAD store ``z_vi = 0.0``)
      R                  clear the view lock and redraw
      Space              annotate the sample nearest the mouse wavelength
      Left / Right       previous / next spectrum
      M                  print the mouse position in data coordinates
    With Ctrl (every Ctrl+key first prints the mouse position):
      Ctrl+R reload the current spectrum, Ctrl+Left first, Ctrl+Right last,
      Ctrl+B resume from history (counter set to ``len(history) - 1`` and
      ``plot_next`` called).
    """
    spec = self.spec
    key = event.key()
    ctrl_held = bool(event.modifiers() & Qt.ControlModifier)
    self._last_committed_objid = None

    def _commit_current(class_vi=None, z_vi=None):
        # Update (or create) the history record for ``spec`` and notify listeners.
        if class_vi is not None:
            record = self._set_classification(spec, class_vi, z_vi)
        else:
            record = self._history_record(spec=spec, create=True)
            record["z_vi"] = getattr(spec, "z_vi", 0.0) if z_vi is None else z_vi
        record["qa_flag"] = int(getattr(spec, "qa_flag", 0) or 0)
        self._last_committed_objid = spec.objid
        self.record_committed.emit(spec.objid)
        return record

    def _mouse_in_view():
        # Cursor position mapped into the plot's data (view) coordinates.
        self.vb = self.getViewBox()
        return self.vb.mapSceneToView(self.mapFromGlobal(QCursor.pos()))

    if ctrl_held:
        # Ctrl combinations are reserved for the advanced actions below; plain
        # shortcuts must not fire too (Ctrl+B used to also classify QSO_BAL).
        print(_mouse_in_view())
        if key == Qt.Key_R:
            self.clear()
            self._clear_view_lock()
            try:
                reloaded = self._load_current_spec(self.counter - 1)
            except Exception as e:
                print(f"Failed to reload spectrum {self.counter}: {e}")
                return
            self.spec = reloaded
            # For a reload the displayed number is the current counter.
            self._displaying_spectrum_number = self.counter
            self.update_slider_and_spin()
            self.plot_single()
        elif key == Qt.Key_Right:
            self.clear()
            self.counter = self.len_list - 1
            self.plot_next()
        elif key == Qt.Key_Left:
            self.clear()
            self.counter = 0
            self.plot_next()
        elif key == Qt.Key_B:
            self.clear()
            self.counter = len(self.history) - 1
            self.plot_next()
        return

    if key == Qt.Key_Q:
        if spec.objid not in self.history:
            _commit_current(self._default_class_token(), spec.z_vi)
        else:
            _commit_current(None, spec.z_vi)
        if self.counter < self.len_list:
            self.clear()
            self.plot_next()
        else:
            print("No more spectra to plot.")
        # Periodic checkpoint of the history (every 50 spectra).
        if (self.counter - 1) % 50 == 0:
            print("Saving temp file to csv (n={})...".format(self.counter))
            temp_filename = f"vi_temp_{self.counter-1}.csv"
            pd.DataFrame(self._history_rows_for_csv()).to_csv(temp_filename, index=False)
        return

    # (key, message, class token, keep the current z_vi?)
    classification_keys = (
        (Qt.Key_S, "\tClass: STAR.", 'STAR', False),
        (Qt.Key_G, "\tClass: GALAXY.", 'GALAXY', True),
        (Qt.Key_A, "\tClass: QSO.", 'QSO', True),
        (Qt.Key_N, "\tClass: QSO(Narrow).", 'QSO_NARROW', True),
        (Qt.Key_B, "\tClass: QSO(BAL).", 'QSO_BAL', True),
        (Qt.Key_F, "\tClass: QSO(FeLoBAL).", 'QSO_FELOBAL', True),
        (Qt.Key_U, "\tClass: UNKNOWN.", 'UNKNOWN', False),
        (Qt.Key_D, "\tClass: BAD spectrum.", 'BAD', False),
        (Qt.Key_L, "\tClass: LIKELY/Unusual QSO.", 'LIKELY_Q', True),
    )
    for class_key, message, class_token, keep_z in classification_keys:
        if key == class_key:
            print(message)
            _commit_current(class_token, spec.z_vi if keep_z else 0.0)
            self.update_spectrum_info_label()
            return

    if key == Qt.Key_R:
        self._clear_view_lock()
        self.clear()
        self.plot_single()
    elif key == Qt.Key_Space:
        # Annotate exactly once (previously two separate branches both ran).
        self._annotate_at_wave(_mouse_in_view().x())
    elif key == Qt.Key_Left:
        self.plot_previous()
    elif key == Qt.Key_Right:
        self.plot_next()
    elif key == Qt.Key_M:
        mouse_pos = _mouse_in_view()
        print(f"Mouse position - Wavelength: {mouse_pos.x():.2f}, Flux: {mouse_pos.y():.2e}")
[docs]
class PGSpecPlotAppEnhanced(QApplication):
    """Enhanced standalone application with image cutouts and template switching.

    Owns a ``PGSpecPlotEnhanced`` plot, an optional ``ImageCutoutWidget`` panel
    and, for AIMS-z review spectra, an ``AIMSZReviewPanel``.  Classification
    history is written to ``output_file`` as CSV on demand and when the
    application is about to quit.
    """
@staticmethod
def _normalize_objid(value):
"""Normalize objid loaded from CSV.
Keeps integers as int (for legacy FITS workflows) and keeps non-numeric
IDs (e.g. SPARCL UUIDs) as str.
"""
if value is None or (isinstance(value, float) and np.isnan(value)):
return value
if isinstance(value, (np.integer, int)):
return int(value)
if isinstance(value, (np.floating, float)):
if float(value).is_integer():
return int(value)
return str(value)
s = str(value)
try:
return int(s)
except Exception:
return s
@staticmethod
def _normalize_qa_flag(value):
"""Normalize qa_flag loaded from CSV/history."""
if value is None or (isinstance(value, float) and np.isnan(value)):
return 0
try:
return int(value)
except Exception:
return 0
@staticmethod
def _clean_scalar(value):
if value is None:
return None
try:
if pd.isna(value):
return None
except Exception:
pass
if isinstance(value, str):
text = value.strip()
return None if text == "" else text
return value
@classmethod
def _is_aimsz_review_class(cls, spec_class):
    """True when ``spec_class`` is ``SpecAIMSZReview`` or a subclass of it.

    Values that ``issubclass`` rejects (e.g. non-class objects) give False.
    """
    try:
        result = issubclass(spec_class, SpecAIMSZReview)
    except Exception:
        result = False
    return result
@classmethod
def _is_sparcl_class(cls, spec_class):
    """True when ``spec_class`` is ``SpecSparcl`` or a subclass of it.

    Values that ``issubclass`` rejects (e.g. non-class objects) give False.
    """
    try:
        result = issubclass(spec_class, SpecSparcl)
    except Exception:
        result = False
    return result
@classmethod
def _normalize_loaded_objid(cls, row, *, is_aimsz_review=False, is_sparcl=False, aimsz_lookup=None):
"""Return the history key (objid) for one row of a saved results CSV.

AIMS-z review rows:
``object_id`` is preferred, then ``objid``, then ``targetid``, each passed
through ``SpecAIMSZReview._canonical_objid``.  An ``objid`` that already
starts with ``aimsz:`` is kept verbatim, and a leading ``targetid`` label
is stripped before canonicalization.  Returns ``None`` when none of these
columns has a value.

SPARCL rows:
An ``objid`` string that already carries a ``sparcl:``, ``targetid:``,
``specid:`` or ``sparcl-row:`` prefix is kept as is.  Otherwise the key is
rebuilt with ``SpecSparcl._canonical_objid`` from ``sparcl_id``,
``targetid`` and ``specid``; ``objid`` is the fallback for the last two.

Other rows:
``objid`` is normalized with ``_normalize_objid``.

Blank or missing cells count as absent (via ``_clean_scalar``).
NOTE(review): ``aimsz_lookup`` is accepted but not used in this method.
"""
if is_aimsz_review:
# Explicit object_id column wins.
object_id = cls._clean_scalar(row.get("object_id", None))
if object_id is not None:
return SpecAIMSZReview._canonical_objid(object_id)
raw_objid = cls._clean_scalar(row.get("objid", None))
if raw_objid is not None:
raw_text = str(raw_objid).strip()
# Already in canonical "aimsz:" form.
if raw_text.startswith("aimsz:"):
return raw_text
# Strip a leading "targetid" label and canonicalize what follows.
if raw_text.startswith("targetid"):
suffix = raw_text[len("targetid"):].strip()
if suffix:
return SpecAIMSZReview._canonical_objid(suffix)
return SpecAIMSZReview._canonical_objid(raw_text)
targetid = cls._clean_scalar(row.get("targetid", None))
if targetid is not None:
return SpecAIMSZReview._canonical_objid(str(targetid).strip())
return None
if is_sparcl:
sparcl_id = cls._clean_scalar(row.get("sparcl_id", None))
targetid = cls._clean_scalar(row.get("targetid", None))
specid = cls._clean_scalar(row.get("specid", None))
raw_objid = cls._clean_scalar(row.get("objid", None))
# Keys written by a previous session already carry a type prefix; keep them.
if isinstance(raw_objid, str):
raw_text = raw_objid.strip()
if raw_text.startswith(("sparcl:", "targetid:", "specid:", "sparcl-row:")):
return raw_text
return SpecSparcl._canonical_objid(
sparcl_id=sparcl_id,
targetid=targetid if targetid is not None else raw_objid,
specid=specid if specid is not None else raw_objid,
filename=None,
row=0,
)
# Legacy / non-SPARCL workflows: plain objid column.
raw_objid = cls._clean_scalar(row.get("objid", None))
return cls._normalize_objid(raw_objid)
@classmethod
def _row_to_history_record(cls, row, *, is_aimsz_review=False):
    """Build an in-memory history record from one row of a saved results CSV.

    ``notes``, ``reviewer`` and ``reviewed_at`` are only populated for AIMS-z
    review sessions; other sessions store ``""`` for them.

    pandas reads empty CSV cells as NaN, and NaN is truthy.  Those fields
    therefore used to round-trip as the literal string ``"nan"``.  Missing
    cells are now treated as empty.  A missing reviewer falls back to
    ``default_reviewer_username()``.
    """
    def _text(key):
        # str() of the cell, with None/NaN/other falsy values mapped to "".
        value = row.get(key, None)
        try:
            if value is not None and pd.isna(value):
                value = None
        except (TypeError, ValueError):
            # Array-likes have an ambiguous truth value; keep them as-is.
            pass
        return str(value or "")

    if is_aimsz_review:
        notes = _text("notes")
        reviewer = _text("reviewer") or str(default_reviewer_username())
        reviewed_at = _text("reviewed_at")
    else:
        notes = reviewer = reviewed_at = ""

    return {
        "objname": row.get("objname", "Unknown"),
        "ra": row.get("ra", np.nan),
        "dec": row.get("dec", np.nan),
        "class_vi": normalize_class_label(row.get("class_vi", "")),
        "z_vi": row.get("z_vi", np.nan),
        "targetid": row.get("targetid", None),
        "data_release": normalize_data_release(row.get("data_release", None), aimsz_review=is_aimsz_review),
        "qa_flag": cls._normalize_qa_flag(row.get("qa_flag", 0)),
        "notes": notes,
        "reviewer": reviewer,
        "reviewed_at": reviewed_at,
        "object_id": row.get("object_id", None),
    }
@staticmethod
def _now_reviewed_at():
return datetime.now(timezone.utc).replace(microsecond=0).isoformat()
def __init__(self, spectra, SpecClass=SpecEuclid1d,
             output_file='vi_output.csv', z_max=5.0, load_history=False,
             euclid_fits=None, cutout_buffer_dir=None, enable_image_panel=None,
             enable_background_prefetch=None,
             rgs_file=None, bgs_file=None, ext=None, extname=None,
             dual_good_pixels_only=False, external_redshift_lookup=None,
             external_redshift_key="object_id"):
    """Create the viewer application, optionally resume history, build the UI.

    Parameters
    ----------
    spectra : str, path-like or sequence of paths
        Spectrum input passed to the plot widget.  A single path or the
        first entry of a sequence locates the default cutout buffer.
    SpecClass : type
        Spectrum class used to load each spectrum.
    output_file : str
        CSV written by the save actions, and read when ``load_history`` is true.
    z_max : float
        Passed to the plot widget.
    load_history : bool
        Resume from ``output_file`` when it exists.
    euclid_fits, rgs_file, bgs_file, ext, extname, dual_good_pixels_only,
    external_redshift_lookup, external_redshift_key
        Forwarded to the plot widget.  Dual mode needs both
        ``rgs_file`` and ``bgs_file``.
    cutout_buffer_dir : path-like, optional
        Cutout cache directory; derived from the input location if omitted.
    enable_image_panel : bool, optional
        Show the image-cutout panel (``None`` means off).
    enable_background_prefetch : bool, optional
        Prefetch cutouts in a background thread.  Defaults to
        ``enable_image_panel`` and is always off without the image panel.
    """
    super().__init__(sys.argv)
    self.output_file = output_file
    self.spectra = spectra
    self.SpecClass = SpecClass
    self._is_aimsz_review = self._is_aimsz_review_class(self.SpecClass)
    self._is_sparcl = self._is_sparcl_class(self.SpecClass)
    self._session_reviewer = default_reviewer_username()
    self.euclid_fits = euclid_fits
    self.enable_image_panel = bool(enable_image_panel)  # None -> False
    if enable_background_prefetch is None:
        enable_background_prefetch = self.enable_image_panel
    self.enable_background_prefetch = bool(enable_background_prefetch) and self.enable_image_panel
    self.rgs_file = rgs_file
    self.bgs_file = bgs_file
    self.dual_ext = ext
    self.dual_extname = extname
    self.dual_good_pixels_only = bool(dual_good_pixels_only)
    self._dual_mode = bool(self.rgs_file and self.bgs_file)
    self.external_redshift_lookup = dict(external_redshift_lookup or {})
    self.external_redshift_key = str(external_redshift_key or "object_id")

    if load_history and os.path.exists(self.output_file):
        history_dict, initial_counter = self._load_history_csv()
    else:
        history_dict, initial_counter = {}, 0

    self.plot = PGSpecPlotEnhanced(
        self.spectra, self.SpecClass,
        initial_counter=initial_counter,
        z_max=z_max,
        history_dict=history_dict,
        euclid_fits=self.euclid_fits,
        rgs_file=self.rgs_file,
        bgs_file=self.bgs_file,
        ext=self.dual_ext,
        extname=self.dual_extname,
        dual_good_pixels_only=self.dual_good_pixels_only,
        external_redshift_lookup=self.external_redshift_lookup,
        external_redshift_key=self.external_redshift_key)
    self.plot._session_reviewer = self._session_reviewer
    self.len_list = self.plot.len_list
    self.plot.current_spec_changed.connect(self.on_current_spec_changed)
    self.plot.record_committed.connect(self.on_record_committed)
    self._last_review_panel_objid = None
    self._last_review_panel_state = None
    self.cutout_widget = None
    self.review_panel = None

    if self._is_aimsz_review:
        self.review_panel = AIMSZReviewPanel(default_reviewer=self._session_reviewer)
        self.review_panel.qa_flag_changed.connect(self.on_review_panel_qa_changed)
        self.review_panel.notes_changed.connect(self.on_review_notes_changed)
        self.review_panel.reviewer_changed.connect(self.on_review_reviewer_changed)

    if self.enable_image_panel:
        if cutout_buffer_dir is not None:
            buffer_dir = Path(cutout_buffer_dir)
        else:
            buffer_dir = self._default_cutout_buffer_dir()
        self.cutout_widget = ImageCutoutWidget(buffer_dir=buffer_dir)
        if self._is_aimsz_review and hasattr(self.cutout_widget, "qa_group"):
            # AIMS-z review sessions edit QA through the review panel instead.
            self.cutout_widget.qa_group.setVisible(False)
        self.cutout_widget.qa_contamination_cb.stateChanged.connect(self.on_qa_flag_changed)
        self.cutout_widget.qa_unusable_cb.stateChanged.connect(self.on_qa_flag_changed)
        self.plot.coordinate_changed.connect(self.on_coordinate_changed)
        spec = getattr(self.plot, 'spec', None)
        if spec is not None and hasattr(spec, 'ra') and hasattr(spec, 'dec'):
            self.cutout_widget.load_online_cutouts(spec.ra, spec.dec, getattr(spec, 'objid', None))

    # Initial QA/review-panel sync for the first spectrum.  It is needed even
    # without an image panel: the review panel must show the first spectrum too.
    self.sync_qa_checkbox_from_current_spec()

    # Start background prefetching for next objects
    if self.enable_background_prefetch:
        self.start_background_prefetch()
    self.make_layout()
    self.aboutToQuit.connect(self.save_dict_todf)

def _load_history_csv(self):
    """Read ``self.output_file`` and return ``(history_dict, initial_counter)``.

    Rows whose objid cannot be derived are skipped.  ``initial_counter`` is
    the CSV row count (skipped rows included).  In AIMS-z review mode the
    last non-empty reviewer found becomes the session reviewer.
    """
    print(f"Loading history from {self.output_file} ...")
    df = pd.read_csv(self.output_file)
    if 'vi_class' in df.columns:
        df = df.rename(columns={'vi_class': 'class_vi'})
    history_dict = {}
    for _, row in df.iterrows():
        objid = self._normalize_loaded_objid(
            row,
            is_aimsz_review=self._is_aimsz_review,
            is_sparcl=self._is_sparcl,
        )
        if objid is None:
            continue
        record = self._row_to_history_record(
            row,
            is_aimsz_review=self._is_aimsz_review,
        )
        history_dict[objid] = record
        if self._is_aimsz_review and record.get("reviewer"):
            self._session_reviewer = str(record.get("reviewer"))
    return history_dict, df.shape[0]

def _default_cutout_buffer_dir(self):
    """Return ``<input directory>/cutout_buffer`` for the current input.

    The input directory comes from the dual ``rgs_file``, a single spectrum
    path (str or ``os.PathLike``), or the first entry of a sequence of paths.
    """
    if self._dual_mode:
        source = self.rgs_file
    elif isinstance(self.spectra, (str, os.PathLike)):
        # Path objects used to fall through to ``spectra[0]`` and raise TypeError.
        source = self.spectra
    else:
        source = self.spectra[0]
    return Path(source).parent / "cutout_buffer"
[docs]
def on_coordinate_changed(self, ra, dec):
    """Load cutouts for the new sky position and resync the QA checkboxes.

    The current spectrum's ``objid`` (or ``None``) is passed along with
    ``ra``/``dec``.  Does nothing when there is no cutout widget.
    """
    widget = self.cutout_widget
    if widget is None:
        return
    objid = None
    if hasattr(self.plot, 'spec'):
        objid = getattr(self.plot.spec, 'objid', None)
    widget.load_online_cutouts(ra, dec, objid)
    self.sync_qa_checkbox_from_current_spec()
[docs]
def on_current_spec_changed(self):
    """React to a newly displayed spectrum.

    In AIMS-z review mode an empty ``reviewer`` on the spectrum is filled
    with the session reviewer.  Then the QA widgets, the go-to spin box and
    the downsample toggle are refreshed, in that order.
    """
    if self._is_aimsz_review and hasattr(self.plot, "spec"):
        current = self.plot.spec
        if not getattr(current, "reviewer", ""):
            current.reviewer = self._session_reviewer
    refreshers = (
        self.sync_qa_checkbox_from_current_spec,
        self._update_go_to_controls,
        self._update_downsample_toggle,
    )
    for refresh in refreshers:
        refresh()
[docs]
def on_record_committed(self, objid):
    """Stamp ``reviewed_at`` on a just-committed record (AIMS-z review only).

    The same timestamp is copied to the displayed spectrum when it is the
    committed object.  The review panel is then force-refreshed.
    """
    if not self._is_aimsz_review:
        return
    record = self.plot._history_record(objid=objid, create=False)
    if record is None:
        return
    stamp = self._now_reviewed_at()
    record["reviewed_at"] = stamp
    shown = getattr(self.plot, "spec", None)
    if shown is not None and getattr(shown, "objid", None) == objid:
        shown.reviewed_at = stamp
    self._sync_review_panel(force=True)
def _commit_current_review_state(self, update_timestamp=False):
if not self._is_aimsz_review or not hasattr(self.plot, "spec"):
return
spec = self.plot.spec
record = self.plot._history_record(spec=spec, create=True)
record["z_vi"] = getattr(spec, "z_vi", record.get("z_vi", 0.0))
record["qa_flag"] = int(getattr(spec, "qa_flag", record.get("qa_flag", 0)) or 0)
record["notes"] = str(getattr(spec, "notes", record.get("notes", "")) or "")
record["reviewer"] = str(
getattr(spec, "reviewer", record.get("reviewer", "")) or self._session_reviewer
)
self._session_reviewer = record["reviewer"]
self.plot._session_reviewer = self._session_reviewer
if self.review_panel is not None:
self.review_panel.set_default_reviewer(self._session_reviewer)
spec.reviewer = record["reviewer"]
if not record.get("class_vi"):
record["class_vi"] = self.plot._default_class_token()
if update_timestamp:
record["reviewed_at"] = self._now_reviewed_at()
spec.reviewed_at = record["reviewed_at"]
self.plot.history[spec.objid] = record
self._sync_review_panel(force=True)
def _review_panel_state(self, spec, record):
if spec is None:
return None
record = record or {}
return (
getattr(spec, "objid", None),
int(record.get("qa_flag", getattr(spec, "qa_flag", 0)) or 0),
str(record.get("notes", getattr(spec, "notes", "")) or ""),
str(record.get("reviewer", getattr(spec, "reviewer", "")) or ""),
str(record.get("reviewed_at", getattr(spec, "reviewed_at", "")) or ""),
getattr(spec, "review_priority_tier", None),
getattr(spec, "review_score", None),
getattr(spec, "review_rank_within_tier", None),
getattr(spec, "review_slice_label", None),
getattr(spec, "z_ref", None),
getattr(spec, "z_ml_expect", None),
getattr(spec, "z_pcf_best", None),
getattr(spec, "pcf_template_best", None),
getattr(spec, "pcf_score_best", None),
)
def _sync_review_panel(self, force=False):
if self.review_panel is None or not hasattr(self.plot, "spec"):
return
spec = self.plot.spec
record = self.plot._history_record(spec=spec, objid=getattr(spec, "objid", None), create=False)
state = self._review_panel_state(spec, record)
if not force and state == self._last_review_panel_state:
return
self.review_panel.set_review_context(spec, record)
self._last_review_panel_objid = getattr(spec, "objid", None)
self._last_review_panel_state = state
[docs]
def sync_qa_checkbox_from_current_spec(self):
    """Restore QA checkboxes from current spec/history.

    The history record's ``qa_flag`` wins; otherwise the spectrum's own value
    (default 0) is used.  The normalized flag is stored on the spectrum and
    pushed to the cutout checkboxes (if any) and the review panel.
    """
    if not hasattr(self.plot, 'spec'):
        return
    spec = self.plot.spec
    record = self.plot._history_record(spec=spec, objid=getattr(spec, 'objid', None), create=False)
    spec_value = getattr(spec, 'qa_flag', 0)
    raw_flag = record.get("qa_flag", spec_value) if record else spec_value
    qa_flag = self._normalize_qa_flag(raw_flag)
    spec.qa_flag = qa_flag
    if self.cutout_widget is not None:
        self.cutout_widget.set_qa_flag(qa_flag)
    self._sync_review_panel(force=False)
[docs]
def on_qa_flag_changed(self, _state):
    """Persist QA flag from checkbox selections into current spec/history.

    The combined flag from the cutout panel is stored on the spectrum.  When
    the spectrum has an ``objid``, it is also stored in its history record
    (created with the default class if needed), and the review panel is
    force-refreshed.
    """
    widget = self.cutout_widget
    if widget is None or not hasattr(self.plot, 'spec'):
        return
    spec = self.plot.spec
    flag = widget.get_qa_flag()
    spec.qa_flag = flag
    objid = getattr(spec, 'objid', None)
    if objid is None:
        return
    record = self.plot._history_record(spec=spec, objid=objid, create=True)
    if not record.get("class_vi"):
        record["class_vi"] = self.plot._default_class_token()
    record["qa_flag"] = flag
    self.plot.history[objid] = record
    self._sync_review_panel(force=True)
[docs]
def on_review_panel_qa_changed(self, qa_flag):
    """Store a QA flag chosen in the review panel.

    The flag (as ``int``) is stored on the spectrum.  When the spectrum has
    an ``objid`` it is also stored in the history record (created with the
    default class if needed) and mirrored to the cutout checkboxes.
    """
    if not hasattr(self.plot, "spec"):
        return
    flag = int(qa_flag)
    spec = self.plot.spec
    spec.qa_flag = flag
    objid = getattr(spec, "objid", None)
    if objid is None:
        return
    record = self.plot._history_record(spec=spec, objid=objid, create=True)
    if not record.get("class_vi"):
        record["class_vi"] = self.plot._default_class_token()
    record["qa_flag"] = flag
    self.plot.history[objid] = record
    widget = self.cutout_widget
    if widget is not None:
        widget.set_qa_flag(flag)
[docs]
def on_review_notes_changed(self, text):
    """Store review-panel notes on the spectrum and its history record.

    The record is created with the default class token when it has no class.
    """
    if not hasattr(self.plot, "spec"):
        return
    spec = self.plot.spec
    spec.notes = text
    record = self.plot._history_record(spec=spec, create=True)
    record["notes"] = text
    record["class_vi"] = record.get("class_vi") or self.plot._default_class_token()
    self.plot.history[spec.objid] = record
[docs]
def on_review_reviewer_changed(self, text):
    """Apply a reviewer name typed in the review panel.

    Blank input falls back to the current session reviewer.  The result
    becomes the session default and is stored on the spectrum and its
    history record (created with the default class if needed).
    """
    if not hasattr(self.plot, "spec"):
        return
    spec = self.plot.spec
    reviewer = str(text or "").strip()
    if not reviewer:
        reviewer = self._session_reviewer
    spec.reviewer = reviewer
    self._session_reviewer = reviewer
    self.plot._session_reviewer = reviewer
    if self.review_panel is not None:
        self.review_panel.set_default_reviewer(reviewer)
    record = self.plot._history_record(spec=spec, create=True)
    record["reviewer"] = reviewer
    if not record.get("class_vi"):
        record["class_vi"] = self.plot._default_class_token()
    self.plot.history[spec.objid] = record
    self._sync_review_panel(force=False)
[docs]
def start_background_prefetch(self):
    """Start background prefetching of cutouts for upcoming spectra.

    A daemon thread loads every spectrum from index ``plot.counter`` to the
    end of the list.  Loading uses dual files, the plot's ``speclist``, or
    extensions of ``self.spectra``.  Spectra exposing ``ra``, ``dec`` and
    ``objid`` are passed to ``cutout_widget.prefetch_cutouts_background``.
    Errors are printed and never propagate.
    """
    def _load(index):
        # Mirror the plot's input modes; may return None in dual mode.
        if self._dual_mode:
            pair = SpecEuclid1dDual(
                rgs_file=self.rgs_file,
                bgs_file=self.bgs_file,
                ext=self.dual_ext if self.dual_ext is not None else index + 1,
                extname=self.dual_extname,
                good_pixels_only=self.dual_good_pixels_only,
            )
            return pair.rgs if pair.rgs is not None else pair.bgs
        if self.plot.speclist is not None:
            # Load from individual files
            return self.SpecClass(self.plot.speclist[index])
        # Load from multi-extension FITS
        return self.SpecClass(self.spectra, ext=index + 1)

    def _worker():
        try:
            start = getattr(self.plot, 'counter', 0)
            upcoming = []
            for index in range(start, self.len_list):
                try:
                    spec = _load(index)
                    if spec is None:
                        continue
                    self.plot._ensure_spec_defaults(spec)
                    # Only prefetch spectra with coordinates and an id.
                    if all(hasattr(spec, name) for name in ('ra', 'dec', 'objid')):
                        upcoming.append(spec)
                except Exception as e:
                    print(f"Error loading spectrum {index} for prefetch: {e}")
                    continue
            if upcoming:
                self.cutout_widget.prefetch_cutouts_background(upcoming, 0, None)
        except Exception as e:
            print(f"Error in background prefetch setup: {e}")

    threading.Thread(target=_worker, daemon=True).start()
def _update_downsample_toggle(self):
if not hasattr(self, "downsample_toggle"):
return
enabled = self.plot._is_current_sparcl_like()
self.downsample_toggle.blockSignals(True)
self.downsample_toggle.setChecked(bool(self.plot.downsampling_enabled))
self.downsample_toggle.setEnabled(enabled)
self.downsample_toggle.blockSignals(False)
[docs]
def on_downsample_toggled(self, checked):
    """Handle changes of the downsample checkbox.

    The new setting is always stored.  SPARCL-like spectra are redrawn,
    preserving the view if it is locked, and the plot gets focus back.
    """
    self.plot.downsampling_enabled = bool(checked)
    needs_redraw = hasattr(self.plot, "spec") and self.plot._is_current_sparcl_like()
    if needs_redraw:
        self.plot.clear()
        self.plot.plot_single(preserve_view=self.plot._view_lock_active)
        self.plot.setFocus()
[docs]
def make_layout(self):
"""Build and show the main viewer window.

Rows of the ``pg.LayoutWidget``:
- Row 0: toolbar with template radio buttons, the optional "Hide Images"
  toggle, the downsample checkbox, the go-to-index spin box and button,
  and "Save PNG", "Save" and "Save & Quit".
- Row 1: keyboard-shortcut help label.
- Row 2: horizontal splitter.  The left pane holds the spectrum info label,
  the plot and the redshift slider.  The right pane, when present, holds
  the AIMS-z review panel and/or the image-cutout panel.

The window is stored as ``self.layout`` and shown.
"""
layout = pg.LayoutWidget()
layout.resize(1300, 900) # Reasonable width with text wrapping
layout.setWindowTitle(f"PGSpecPlot Enhanced - Spectra Viewer (v{viewer_version})")
# NOTE(review): counter looks like a 1-based index (<= len_list), so this
# guard appears to be always true in practice; confirm before relying on it.
if self.plot.counter < self.len_list + 1:
# Create toolbar with template controls and save buttons
toolbar = QWidget()
toolbar_layout = QHBoxLayout()
# Template selection
template_group = QGroupBox("Template")
template_group.setMaximumHeight(45) # Make template group more compact
template_layout = QHBoxLayout()
template_layout.setContentsMargins(5, 2, 5, 2) # Reduce margins
self.template_buttons = QButtonGroup()
template_names = self.plot.template_manager.get_available_templates()
for i, template_name in enumerate(template_names):
btn = QRadioButton(template_name)
if template_name == "Type 1":
btn.setChecked(True)
# ``name=template_name`` binds the name per button (avoids late binding).
btn.clicked.connect(lambda checked, name=template_name:
self.plot.change_template(name) if checked else None)
self.template_buttons.addButton(btn)
template_layout.addWidget(btn)
template_group.setLayout(template_layout)
toolbar_layout.addWidget(template_group)
if self.enable_image_panel and self.cutout_widget is not None:
self.image_toggle_btn = QPushButton("Hide Images")
self.image_toggle_btn.clicked.connect(self.toggle_image_panel)
self.image_toggle_btn.setCheckable(True)
self.image_toggle_btn.setMaximumHeight(35) # Make button more compact
toolbar_layout.addWidget(self.image_toggle_btn)
# stateChanged delivers an int state; on_downsample_toggled coerces it with bool().
self.downsample_toggle = QCheckBox("Downsample (n=3)")
self.downsample_toggle.setMaximumHeight(35)
self.downsample_toggle.setToolTip(
"Use pyqtgraph native downsampling for SPARCL/AIMS-z spectra "
"(mean + auto + clip-to-view)."
)
self.downsample_toggle.stateChanged.connect(self.on_downsample_toggled)
toolbar_layout.addWidget(self.downsample_toggle)
toolbar_layout.addWidget(QLabel("Go to index:"))
self.goto_index_spin = QSpinBox()
self.goto_index_spin.setMinimum(1)
self.goto_index_spin.setMaximum(self.len_list)
self.goto_index_spin.setValue(1)
self.goto_index_spin.setMaximumHeight(35)
# Avoid stealing keyboard navigation shortcuts on startup.
self.goto_index_spin.setFocusPolicy(Qt.ClickFocus)
toolbar_layout.addWidget(self.goto_index_spin)
self.goto_index_btn = QPushButton("Go")
self.goto_index_btn.setMaximumHeight(35)
self.goto_index_btn.clicked.connect(self.go_to_index)
toolbar_layout.addWidget(self.goto_index_btn)
# Add spacer
toolbar_layout.addStretch()
# Save buttons
self.save_png_btn = QPushButton("Save PNG")
self.save_png_btn.clicked.connect(self.save_png)
self.save_png_btn.setMaximumHeight(35)
toolbar_layout.addWidget(self.save_png_btn)
self.save_btn = QPushButton("Save")
self.save_btn.clicked.connect(self.save_data)
self.save_btn.setMaximumHeight(35) # Make button more compact
toolbar_layout.addWidget(self.save_btn)
self.save_quit_btn = QPushButton("Save & Quit")
self.save_quit_btn.clicked.connect(self.save_and_quit)
self.save_quit_btn.setMaximumHeight(35) # Make button more compact
toolbar_layout.addWidget(self.save_quit_btn)
toolbar.setLayout(toolbar_layout)
toolbar.setMaximumHeight(50) # Limit toolbar height
toolbar.setMinimumHeight(40) # Set minimum height
layout.addWidget(toolbar, row=0, col=0, colspan=2)
# Instructions with comprehensive keyboard shortcuts
instruction_text = (
"Navigation: 'Q' next spectrum, Left/Right arrows previous/next | "
"Classification: 'A' QSO, 'N' QSO(Narrow), 'B' QSO(BAL), 'F' QSO(FeLoBAL), 'D' BAD spectrum, | "
"'S' STAR, 'G' GALAXY, 'U' UNKNOWN, 'L' LIKELY_Q | "
"Tools: 'Space' wavelength info, 'M' mouse position, 'R' reset zoom | "
"Advanced: Ctrl+R reload, Ctrl+Left first, Ctrl+Right last, Ctrl+B resume from history"
)
toplabel = layout.addLabel(instruction_text, row=1, col=0, colspan=2)
toplabel.setFont(QFont("Arial", 15))
toplabel.setMinimumHeight(72)
toplabel.setMaximumHeight(96)
toplabel.setAlignment(Qt.AlignLeft | Qt.AlignTop)
toplabel.setStyleSheet("background-color: white;color: black;")
toplabel.setFrameStyle(QFrame.Panel | QFrame.Raised)
toplabel.setWordWrap(True)
# Main content area with splitter
main_splitter = QSplitter(Qt.Horizontal)
# Left side: spectrum plot and controls
left_widget = QWidget()
left_layout = QVBoxLayout()
# Add spectrum info label above plot
left_layout.addWidget(self.plot.spectrum_info_label)
left_layout.addWidget(self.plot)
# Redshift slider
slider_container = QWidget()
slider_layout = QHBoxLayout()
slider_layout.addWidget(QLabel("Redshift:"))
slider_layout.addWidget(self.plot.slider)
slider_layout.addWidget(self.plot.redshiftSpin)
slider_container.setLayout(slider_layout)
left_layout.addWidget(slider_container)
left_widget.setLayout(left_layout)
main_splitter.addWidget(left_widget)
# Right-hand column only when a review panel or a cutout panel exists.
right_widget = None
if self.review_panel is not None or (self.enable_image_panel and self.cutout_widget is not None):
right_widget = QWidget()
right_layout = QVBoxLayout()
if self.review_panel is not None:
right_layout.addWidget(self.review_panel)
if self.enable_image_panel and self.cutout_widget is not None:
right_layout.addWidget(self.cutout_widget)
right_layout.addStretch()
right_widget.setLayout(right_layout)
main_splitter.addWidget(right_widget)
main_splitter.setSizes(self._default_splitter_sizes())
main_splitter.setStretchFactor(0, 5 if self._is_aimsz_review else 4)
main_splitter.setStretchFactor(1, 1)
main_splitter.setCollapsible(1, True)
# Without a right-hand column the splitter has a single pane.
else:
main_splitter.setSizes([1000])
layout.addWidget(main_splitter, row=2, col=0, colspan=2)
self.main_splitter = main_splitter # Store reference for toggle
self._update_go_to_controls()
# Keep keyboard shortcuts active by default.
self.plot.setFocusPolicy(Qt.StrongFocus)
self.plot.setFocus()
self.layout = layout
self.layout.show()
self._update_downsample_toggle()
def _default_splitter_sizes(self):
if self._is_aimsz_review:
if self.enable_image_panel and self.cutout_widget is not None:
return [900, 280]
return [980, 220]
if self.review_panel is not None and self.cutout_widget is not None:
return [780, 320]
if self.review_panel is not None:
return [860, 280]
if self.cutout_widget is not None:
return [800, 200]
return [1000]
[docs]
def keyPressEvent(self, event):
    """Forward keyboard events to plot widget.

    After forwarding, the QA widgets and go-to spin box are resynced and
    focus is returned to the plot.

    NOTE(review): this class derives from QApplication, which is not a
    QWidget, so Qt presumably never calls this; it only runs if invoked
    explicitly.
    """
    plot = self.plot
    plot.keyPressEvent(event)
    self.sync_qa_checkbox_from_current_spec()
    self._update_go_to_controls()
    plot.setFocus()
def _current_index_one_based(self):
if hasattr(self.plot, "_displaying_spectrum_number"):
return int(self.plot._displaying_spectrum_number)
counter = int(getattr(self.plot, "counter", 1))
if counter < 1:
return 1
if counter > self.len_list:
return self.len_list
return counter
def _update_go_to_controls(self):
if not hasattr(self, "goto_index_spin"):
return
current_idx = self._current_index_one_based()
self.goto_index_spin.blockSignals(True)
self.goto_index_spin.setValue(current_idx)
self.goto_index_spin.blockSignals(False)
[docs]
def go_to_index(self):
    """Handle the "Go" button: jump to the index in the spin box.

    Afterwards the QA widgets and spin box are resynced and focus returns
    to the plot.
    """
    if not hasattr(self, "goto_index_spin"):
        return
    target = int(self.goto_index_spin.value())
    plot = self.plot
    plot.jump_to_spectrum(target)
    self.sync_qa_checkbox_from_current_spec()
    self._update_go_to_controls()
    plot.setFocus()
[docs]
def mousePressEvent(self, event):
    """Pass a mouse-press event straight through to the plot widget."""
    target = self.plot
    target.mousePressEvent(event)
[docs]
def save_data(self):
    """Write results to ``output_file`` and confirm with an info dialog."""
    self.save_dict_todf()
    message = f"Data saved to {self.output_file}"
    QMessageBox.information(self.layout, "Saved", message)
[docs]
def save_png(self):
    """Save the entire application window as a PNG image.

    The file goes to ``./saved_pngs`` with one of these names:
    - ``euclid_<objid>_vi.png`` for spectra whose ``telescope`` is "euclid";
    - ``<survey>_<targetid or objid>_vi.png`` for SPARCL / AIMS-z review
      spectra (survey inferred from ``data_release``);
    - ``<objid>_vi.png`` otherwise.

    Existing files are never overwritten: ``_2``, ``_3``, ... is appended.
    The outcome is reported with a message box.
    """
    def _sanitize_component(value):
        # Keep only filesystem-friendly characters.
        s = str(value).strip()
        s = s.replace(os.sep, "_").replace(" ", "_")
        s = re.sub(r"[^0-9A-Za-z_.-]+", "_", s)
        return s.strip("_") or "unknown"

    def _infer_survey(spec):
        dr = str(getattr(spec, "data_release", "") or getattr(spec, "_dr", "") or "")
        dr_l = dr.lower()
        if "desi" in dr_l:
            return "desi"
        if "sdss" in dr_l or "boss" in dr_l or "eboss" in dr_l:
            return "sdss"
        if "lamost" in dr_l:
            return "lamost"
        if "gaia" in dr_l:
            return "gaia"
        if dr_l:
            # keep it short and filesystem-friendly
            return _sanitize_component(dr_l)[:24]
        return "sparcl"

    spec = getattr(self.plot, "spec", None)
    objid = getattr(spec, "objid", "spectrum") if spec is not None else "spectrum"
    objid_str = _sanitize_component(objid)
    # ``telescope`` may exist but be None (or non-str); the previous
    # ``getattr(...).lower()`` raised AttributeError in that case.
    telescope = str(getattr(spec, "telescope", "") or "").lower() if spec is not None else ""
    is_sparcl_like = (
        (("SpecSparcl" in globals()) and isinstance(spec, SpecSparcl))
        or (("SpecAIMSZReview" in globals()) and isinstance(spec, SpecAIMSZReview))
    )
    if telescope == "euclid":
        base_name = f"euclid_{objid_str}_vi.png"
    elif is_sparcl_like:
        survey = _infer_survey(spec)
        targetid = getattr(spec, "targetid", None)
        if targetid not in (None, "", 0):
            base_name = f"{survey}_{_sanitize_component(targetid)}_vi.png"
        else:
            base_name = f"{survey}_{objid_str}_vi.png"
    else:
        base_name = f"{objid_str}_vi.png"

    out_dir = Path.cwd() / "saved_pngs"
    try:
        out_dir.mkdir(parents=True, exist_ok=True)
    except Exception as e:
        QMessageBox.warning(self.layout, "Save PNG failed", f"Could not create {out_dir}: {e}")
        return

    # Pick the first free name: base, then base_2, base_3, ...
    filename = out_dir / base_name
    stem = filename.stem
    suffix_index = 2
    while filename.exists():
        filename = out_dir / f"{stem}_{suffix_index}.png"
        suffix_index += 1

    try:
        # Ensure latest visuals are painted before grabbing.
        self.layout.repaint()
        QApplication.processEvents()
        pixmap = self.layout.grab()
        ok = pixmap.save(str(filename), "PNG")
    except Exception as e:
        QMessageBox.warning(self.layout, "Save PNG failed", str(e))
        return
    if not ok:
        QMessageBox.warning(self.layout, "Save PNG failed", "Qt failed to write the PNG file.")
        return
    QMessageBox.information(self.layout, "Saved", f"Saved PNG to {filename}")
[docs]
def save_and_quit(self):
    """Save data and quit application.

    Saving and the confirmation dialog are delegated to ``save_data``.
    ``save_dict_todf`` is also connected to ``aboutToQuit``, so it runs once
    more while quitting.
    """
    self.save_data()
    self.quit()
[docs]
def toggle_image_panel(self):
    """Show or hide the image cutout panel according to the toggle button.

    Hiding collapses the right splitter pane and turns auto-fetch off.
    Showing restores the default pane sizes and turns auto-fetch back on.
    """
    if not self.enable_image_panel or self.cutout_widget is None:
        return
    widget = self.cutout_widget
    button = self.image_toggle_btn
    if button.isChecked():
        if hasattr(self, "main_splitter"):
            self.main_splitter.setSizes([1000, 0])
        widget.setVisible(False)
        button.setText("Show Images")
        widget.auto_fetch_cb.setChecked(False)
        return
    widget.setVisible(True)
    if hasattr(self, "main_splitter"):
        self.main_splitter.setSizes(self._default_splitter_sizes())
    button.setText("Hide Images")
    widget.auto_fetch_cb.setChecked(True)
[docs]
def run_cross_correlation(self):
    """Run cross-correlation analysis - placeholder ("coming soon" dialog).

    ``QMessageBox.information`` needs a ``QWidget`` (or ``None``) as parent.
    This object is a ``QApplication``, so passing ``self`` failed.  The main
    window (``self.layout``, once built) is used instead, as the other
    dialogs in this class do.
    """
    parent = getattr(self, "layout", None)
    if not isinstance(parent, QWidget):
        parent = None
    QMessageBox.information(parent, "Cross-Correlation",
                            "Cross-correlation feature is coming soon!\n\n"
                            "This will perform automatic redshift measurement\n"
                            "using template cross-correlation.")
[docs]
def save_dict_todf(self):
    """Save classification results to CSV (``output_file``).

    Does nothing when the history is empty.  In AIMS-z review sessions the
    displayed spectrum's review state is committed first, with a fresh
    ``reviewed_at``, and the output is limited to a fixed column order.
    """
    if not self.plot.history:
        return
    self._commit_current_review_state(update_timestamp=self.plot._is_aimsz_review)
    table = pd.DataFrame(self.plot._history_rows_for_csv())
    if self.plot._is_aimsz_review:
        review_columns = [
            "objid", "targetid", "ra", "dec", "data_release",
            "class_vi", "z_vi", "qa_flag", "notes", "reviewer", "reviewed_at",
        ]
        table = table.reindex(columns=review_columns)
    table.to_csv(self.output_file, index=False)
    print(f"Results saved to {self.output_file}")
[docs]
class PGSpecPlotThreadEnhanced(QThread):
    """Enhanced thread wrapper for the enhanced application.

    Builds a ``PGSpecPlotAppEnhanced`` for the given spectra and runs its
    event loop in :meth:`run`. When the image panel is enabled and the cutout
    buffer folder does not exist yet, the user may be offered (on an
    interactive terminal) a one-time bulk cutout pre-download. If that
    download completes well enough, no window is created and :meth:`run`
    returns immediately so the user can relaunch.
    """

    def __init__(self, spectra=None, SpecClass=SpecEuclid1d, specfile=None, **kwargs):
        """Prepare the application, or a bulk cutout pre-download.

        Parameters
        ----------
        spectra : str, path-like, or sequence of paths, optional
            Spectrum input handed to the application.
        SpecClass : type
            Spectrum reader class used for non-dual inputs.
        specfile : optional
            Backward-compatible alias for ``spectra``; used only when
            ``spectra`` is None.
        **kwargs
            Forwarded to ``PGSpecPlotAppEnhanced``, except
            ``cutout_buffer_dir``, which is consumed here.
            ``enable_image_panel`` defaults to False.
            ``rgs_file``/``bgs_file`` (both set) enable dual mode, refined by
            ``ext``, ``extname`` and ``dual_good_pixels_only``.

        Raises
        ------
        ValueError
            If neither ``spectra``/``specfile`` nor a dual
            ``rgs_file``/``bgs_file`` pair is given.
        """
        super().__init__()
        explicit_buffer_dir = kwargs.pop("cutout_buffer_dir", None)
        # The image panel is opt-in: default it to off unless the caller chose.
        kwargs.setdefault("enable_image_panel", False)
        self.enable_image_panel = bool(kwargs["enable_image_panel"])
        self.rgs_file = kwargs.get("rgs_file", None)
        self.bgs_file = kwargs.get("bgs_file", None)
        self.dual_ext = kwargs.get("ext", None)
        self.dual_extname = kwargs.get("extname", None)
        self.dual_good_pixels_only = bool(kwargs.get("dual_good_pixels_only", False))
        self._dual_mode = bool(self.rgs_file and self.bgs_file)
        self._dual_parquet_mode = (
            self._dual_mode
            and _is_dataframe_backed_spectrum_path(self.rgs_file)
            and _is_dataframe_backed_spectrum_path(self.bgs_file)
        )
        self._dual_pair_keys = []
        # Backward compatibility: accept `specfile` when `spectra` is missing.
        if spectra is None and specfile is not None:
            self.spectra = specfile
        elif spectra is not None:
            self.spectra = spectra
        elif self._dual_mode:
            self.spectra = self.rgs_file
        else:
            raise ValueError("Either 'spectra' or 'specfile' must be provided")
        self.SpecClass = SpecClass
        self.app = None
        self._skip_window = False
        # Without an image panel there is nothing to prefetch in the background.
        self._disable_background_prefetch = not self.enable_image_panel
        self.buffer_dir = None
        if self._dual_parquet_mode:
            self._dual_pair_keys, _ = _ordered_shared_dual_pair_keys(self.rgs_file, self.bgs_file)
        if self.enable_image_panel:
            self.buffer_dir = (
                Path(explicit_buffer_dir)
                if explicit_buffer_dir
                else self._resolve_buffer_dir(self.spectra)
            )
            if self._should_offer_predownload(self.buffer_dir) and self._prompt_for_bulk_download():
                # A successful bulk download ends this run; the user relaunches.
                self._skip_window = self._run_bulk_predownload()
        if not self._skip_window:
            kwargs.setdefault("enable_background_prefetch", not self._disable_background_prefetch)
            self.app = PGSpecPlotAppEnhanced(
                self.spectra,
                self.SpecClass,
                cutout_buffer_dir=self.buffer_dir,
                **kwargs,
            )

    @staticmethod
    def _resolve_buffer_dir(spectra):
        """Return the default cutout buffer folder for ``spectra``.

        The folder is ``cutout_buffer`` next to the input file, or next to the
        first file when ``spectra`` is a sequence of paths.
        """
        # Path objects are not subscriptable, so every path-like (not only
        # str) is a single input file rather than a sequence of files.
        first = spectra if isinstance(spectra, (str, os.PathLike)) else spectra[0]
        return Path(first).parent / "cutout_buffer"

    @staticmethod
    def _count_hdus_in_file(filename):
        """Return the number of spectra in a dual-mode input file.

        Uses ``SpecEuclid1d.count_in_file`` when available, and otherwise
        counts FITS extensions (all HDUs except the primary one).
        """
        try:
            if hasattr(SpecEuclid1d, "count_in_file"):
                return int(SpecEuclid1d.count_in_file(filename))
        except Exception:
            # Best effort: fall back to counting FITS extensions below.
            pass
        with fits.open(filename) as hdul:
            return max(0, len(hdul) - 1)

    @staticmethod
    def _should_offer_predownload(buffer_dir):
        """Offer predownload only when cutout buffer does not exist yet."""
        return buffer_dir is not None and not Path(buffer_dir).exists()

    @staticmethod
    def _prompt_for_bulk_download():
        """Ask on the terminal whether to pre-download all cutouts.

        Returns True for "y"/"yes" and False for "n"/"no"/empty input; any
        other answer re-prompts. Returns False without prompting when stdin is
        not an interactive terminal, and also when input ends (EOF) before an
        answer is given.
        """
        if not sys.stdin or not sys.stdin.isatty():
            print("No interactive terminal detected; continuing with on-the-fly cutout download.")
            return False
        while True:
            try:
                answer = input(
                    "No 'cutout_buffer' folder found. Download all cutouts before launching the Qt window? [y/N]: "
                ).strip().lower()
            except EOFError:
                # e.g. Ctrl-D at the prompt: take the default answer instead of crashing.
                print()
                return False
            if answer in ("y", "yes"):
                return True
            if answer in ("", "n", "no"):
                return False
            print("Please answer 'y' or 'n'.")

    @staticmethod
    def _cutout_record(spec):
        """Return the ``objid``/``ra``/``dec`` record of ``spec`` (missing attributes become None)."""
        return {
            "objid": getattr(spec, "objid", None),
            "ra": getattr(spec, "ra", None),
            "dec": getattr(spec, "dec", None),
        }

    def _dual_selectors(self):
        """Return the ``(ext, extname)`` pairs to load in dual mode.

        An empty list means the spectra to load could not be determined.
        """
        if self._dual_parquet_mode:
            if self.dual_extname not in (None, ""):
                return [(None, str(self.dual_extname))]
            if self.dual_ext is not None:
                # `ext` is 1-based; map it onto the ordered shared pair keys.
                pair_index = int(self.dual_ext) - 1
                if not 0 <= pair_index < len(self._dual_pair_keys):
                    return []
                return [(None, self._dual_pair_keys[pair_index])]
            return [(None, key) for key in self._dual_pair_keys]
        if self.dual_ext is not None or self.dual_extname not in (None, ""):
            return [(self.dual_ext, self.dual_extname)]
        try:
            total = self._count_hdus_in_file(self.rgs_file)
        except Exception as exc:
            print(f"Failed to determine dual spectrum count for predownload: {exc}")
            return []
        return [(ext, self.dual_extname) for ext in range(1, total + 1)]

    def _collect_dual_cutout_records(self):
        """Collect cutout records for dual input, preferring the rgs spectrum over the bgs one."""
        records = []
        for ext, extname in self._dual_selectors():
            try:
                dual = SpecEuclid1dDual(
                    rgs_file=self.rgs_file,
                    bgs_file=self.bgs_file,
                    ext=ext,
                    extname=extname,
                    good_pixels_only=self.dual_good_pixels_only,
                )
                spec = dual.rgs if dual.rgs is not None else dual.bgs
                if spec is None:
                    continue
                records.append(self._cutout_record(spec))
            except Exception as exc:
                selector_desc = extname if extname not in (None, "") else ext
                print(f"Failed to load dual spectrum selector={selector_desc} for predownload: {exc}")
        return records

    def _collect_cutout_records(self):
        """Collect objid/ra/dec records for all spectra.

        Spectra that fail to load are reported on stdout and skipped. An empty
        list is returned when the set of spectra cannot be determined.
        """
        if self._dual_mode:
            return self._collect_dual_cutout_records()

        records = []
        if isinstance(self.spectra, (list, tuple, np.ndarray)):
            # One spectrum per file.
            for filename in list(self.spectra):
                try:
                    records.append(self._cutout_record(self.SpecClass(filename)))
                except Exception as exc:
                    print(f"Failed to load spectrum '{filename}' for predownload: {exc}")
            return records

        # A single file holding several spectra, addressed by 1-based `ext`.
        try:
            if hasattr(self.SpecClass, "count_in_file"):
                total = int(self.SpecClass.count_in_file(self.spectra))
            else:
                with fits.open(self.spectra) as hdul:
                    total = len(hdul) - 1
        except Exception as exc:
            print(f"Failed to determine spectrum count for predownload: {exc}")
            return records
        for ext in range(1, total + 1):
            try:
                records.append(self._cutout_record(self.SpecClass(self.spectra, ext=ext)))
            except Exception as exc:
                print(f"Failed to load spectrum ext={ext} for predownload: {exc}")
        return records

    def _run_bulk_predownload(self):
        """Pre-download cutouts for every input spectrum into ``self.buffer_dir``.

        Returns True when the download finished and the user is asked to
        relaunch (the window is then skipped). Returns False to continue with
        on-the-fly downloads in two cases: no records could be collected, or
        at least half of the attempted (non-skipped) downloads failed or had
        no data. In the second case background prefetch is also disabled for
        this run.
        """
        print(f"Preparing bulk cutout download into '{self.buffer_dir}' ...")
        records = self._collect_cutout_records()
        if not records:
            print("No valid spectra records found for bulk predownload. Continuing with on-the-fly downloads.")
            return False
        summary = predownload_cutouts(
            records=records,
            buffer_dir=self.buffer_dir,
            surveys=EUCLID_CUTOUT_SURVEYS,
            size_arcsec=10,
            progress_callback=print_cli_progress,
        )
        attempted = max(summary["total"] - summary["skipped"], 0)
        failed_total = summary["failed"] + summary["no_data"]
        fail_rate = failed_total / attempted if attempted > 0 else 0.0
        print(
            "Bulk cutout download finished. "
            f"downloaded={summary['downloaded']}, skipped={summary['skipped']}, "
            f"no_data={summary['no_data']}, failed={summary['failed']}."
        )
        if fail_rate >= 0.5:
            print(
                "Bulk predownload failure rate is high; continuing now with on-the-fly downloads "
                "and disabling background prefetch for this run."
            )
            self._disable_background_prefetch = True
            return False
        print("Please run the program again to launch the Qt window with the prebuilt cutout buffer.")
        return True

    def run(self):
        """Run the application's event loop, then exit the process with its exit code.

        Returns immediately when a bulk cutout pre-download replaced the window.
        """
        if self._skip_window:
            return
        exit_code = self.app.exec_()
        sys.exit(exit_code)