Source code for specbox.qtmodule.qtmodule_enhanced

from PySide6.QtGui import QCursor, QFont, QPixmap
from PySide6.QtCore import Qt, QThread, Signal, QTimer, QThreadPool, QRunnable
from PySide6.QtWidgets import (QApplication, QFrame, QWidget, QSlider, QHBoxLayout, 
                               QVBoxLayout, QGridLayout, QDoubleSpinBox, QPushButton, 
                               QLabel, QButtonGroup, QRadioButton, QGroupBox, 
                               QFileDialog, QMessageBox, QComboBox, QProgressBar, QSpinBox,
                               QScrollArea, QCheckBox, QDialog, QTextEdit, QSplitter,
                               QLineEdit)
import sys
from ..basemodule import *
import pyqtgraph as pg
import numpy as np
from astropy.table import Table
from astropy.io import fits
from astropy.stats import sigma_clip
from astropy.coordinates import SkyCoord
from astropy import units as u
import pandas as pd
from importlib.resources import files
from pathlib import Path
import os
import requests
from urllib.parse import urlencode
import tempfile
from PIL import Image
import io
import json
import hashlib
import time
import re
import concurrent.futures
import threading
from datetime import datetime, timezone
import getpass
from specutils import Spectrum
from importlib.metadata import PackageNotFoundError, version as dist_version
from ..auxmodule.cutout_download import (
    EUCLID_CUTOUT_SURVEYS,
    fetch_cutout,
    get_cache_filename,
    is_no_data_error,
    is_valid_cutout_target,
    load_cutout_from_cache,
    predownload_cutouts,
    print_cli_progress,
    save_cutout_to_cache,
)
from ..auxmodule.external_redshift import normalize_redshift_lookup_key

# Locate the template data files shipped inside the ``specbox`` package
# (resolved through importlib.resources rather than a path relative to this file).
data_path = Path(files("specbox").joinpath("data/templates"))
# Default Type 1 (optical + NIR) QSO template; loaded eagerly just below.
fits_file = data_path / "qso1" / "optical_nir_qso_template_v1.fits"
# Optional templates; TemplateManager checks that each exists before loading it.
ragn_dr1_template_file = data_path / "qso1" / "ragn_dr1.fits"
type2_template_file = data_path / "qso2" / "type2_quasar_composite.csv"
type2_euclid_template_file = data_path / "qso2" / "ragn_na.fits"
# NOTE: read at import time, so importing this module requires the file to be present.
tb_temp = Table.read(str(fits_file))
# Normalise column names to the 'Wave'/'Flux' spelling used by TemplateManager.
tb_temp.rename_columns(['wavelength', 'flux'], ['Wave', 'Flux'])
# Version string of the installed ``specbox`` distribution; "0.0.0" when no
# installed distribution metadata can be found.
try:
    viewer_version = dist_version("specbox")
except PackageNotFoundError:
    viewer_version = "0.0.0"

# Rest-frame emission lines drawn as template annotations, as
# (label, wavelength) pairs with wavelengths in Angstrom.
_TEMPLATE_EMISSION_LINES = [
    ("Ly α", 1215.67), ("C IV", 1549.06), ("C III]", 1908.73),
    ("Mg II", 2798.75), ("[O II]", 3728.48), ("Hβ", 4862.68),
    ("[O III]", 4960.30), ("[O III]", 5008.24), ("Hα", 6564.61),
    ("O I", 8448.7), ("[S III]", 9071.1), ("[S III]", 9533.2),
    ("Pa δ", 10052.1), ("He I", 10833.2), ("Pa γ", 10941.1),
    ("O I", 11290.0), ("Pa β", 12821.6),
]

# RGBA colour of the template overlay.
_TEMPLATE_COLOR = (220, 0, 0, 230)

# Canonical class token -> label shown in the UI.
_CANONICAL_CLASS_TO_DISPLAY = dict(
    QSO_DEFAULT="QSO(Default)",
    QSO="QSO",
    QSO_NARROW="QSO(Narrow)",
    QSO_BAL="QSO(BAL)",
    QSO_FELOBAL="QSO(FeLoBAL)",
    LIKELY_Q="LIKELY_Q",
    GALAXY="GALAXY",
    STAR="STAR",
    UNKNOWN="UNKNOWN",
    BAD="BAD",
)

# Every accepted spelling -> canonical token: each canonical token maps to
# itself, each display label maps back to its token, plus legacy spellings.
_CLASS_LABEL_ALIASES = {
    **{token: token for token in _CANONICAL_CLASS_TO_DISPLAY},
    **{shown: token for token, shown in _CANONICAL_CLASS_TO_DISPLAY.items()},
    "QSO(narrow)": "QSO_NARROW",
    "LIKELY": "LIKELY_Q",
}


def normalize_class_label(value):
    """Map any accepted class spelling to its canonical token.

    ``None``, blank strings and ``"nan"`` (any case) become ``""``.  Labels
    found in ``_CLASS_LABEL_ALIASES`` are translated; anything else is
    returned stripped but otherwise unchanged.
    """
    cleaned = "" if value is None else str(value).strip()
    if not cleaned or cleaned.lower() == "nan":
        return ""
    return _CLASS_LABEL_ALIASES.get(cleaned, cleaned)
def display_class_label(value):
    """Return the UI label for *value*, accepting any recognised spelling.

    Labels without a display form are returned in normalised form.
    """
    token = normalize_class_label(value)
    if token in _CANONICAL_CLASS_TO_DISPLAY:
        return _CANONICAL_CLASS_TO_DISPLAY[token]
    return token
def default_reviewer_username():
    """Return the current login name (stripped), or ``""`` if it is unavailable."""
    try:
        login = getpass.getuser()
    except Exception:
        # getpass.getuser() raises when no login name can be determined;
        # an empty reviewer name is an acceptable fallback here.
        login = None
    return str(login or "").strip()
def normalize_data_release(value, *, aimsz_review=False):
    """Normalise a data-release label.

    Returns ``None`` for ``None``, blank strings and ``"nan"`` (any case).
    In AIMS-z review mode, any label mentioning both "desi" and "dr1"
    (case-insensitive) is collapsed to ``"DESI-DR1"``.  Every other label is
    returned stripped.
    """
    if value is None:
        return None
    text = str(value).strip()
    lowered = text.lower()
    if lowered in ("", "nan"):
        return None
    is_desi_dr1 = "desi" in lowered and "dr1" in lowered
    return "DESI-DR1" if (aimsz_review and is_desi_dr1) else text
def load_template_table(template_path):
    """Load a template spectrum as plain ``wave``/``flux`` float arrays.

    ``.csv`` files are read with astropy's ``ascii.csv`` reader; any other
    suffix is passed to ``Table.read`` format auto-detection.  Columns are
    matched case-insensitively by name (``wavelength``/``wave``/``lambda``
    for the wavelength, ``flux`` for the flux).  A column not found by name
    is taken from the remaining columns in table order, so the wavelength
    and flux arrays never come from the same column.

    Parameters
    ----------
    template_path : str or path-like
        Template file to read.

    Returns
    -------
    dict
        ``{"wave": ndarray, "flux": ndarray}`` (float dtype).

    Raises
    ------
    ValueError
        If two distinct wavelength/flux columns cannot be resolved.
    """
    template_path = Path(template_path)
    if template_path.suffix.lower() == ".csv":
        table = Table.read(str(template_path), format="ascii.csv")
    else:
        table = Table.read(str(template_path))

    colnames = list(table.colnames)
    by_lower = {str(col).lower(): col for col in colnames}

    def _named(aliases):
        # First column whose lower-cased name matches one of *aliases*, else None.
        for alias in aliases:
            if alias in by_lower:
                return by_lower[alias]
        return None

    wave_col = _named(("wavelength", "wave", "lambda"))
    flux_col = _named(("flux",))
    # Fill unresolved roles positionally from the columns not already claimed;
    # for unnamed tables this reproduces "first column = wave, second = flux".
    remaining = [col for col in colnames if col not in (wave_col, flux_col)]
    if wave_col is None and remaining:
        wave_col = remaining.pop(0)
    if flux_col is None and remaining:
        flux_col = remaining.pop(0)

    if wave_col is None or flux_col is None or wave_col == flux_col:
        raise ValueError(f"Template has unexpected columns: {template_path}")
    return {
        "wave": np.asarray(table[wave_col], dtype=float),
        "flux": np.asarray(table[flux_col], dtype=float),
    }
def _is_dataframe_backed_spectrum_path(path):
    """Return whether SpecEuclid1d treats *path* as a DataFrame-backed spectrum file."""
    checker = SpecEuclid1d._is_dataframe_backed_path
    return checker(path)


def _choose_dual_pair_key_column(rgs_df, bgs_df):
    """Return the first of ``extname``/``objid`` present in both frames, else ``None``."""
    shared = [name for name in ("extname", "objid") if name in rgs_df.columns and name in bgs_df.columns]
    return shared[0] if shared else None


def _ordered_shared_dual_pair_keys(rgs_file, bgs_file):
    """List the pairing keys present in both RGS and BGS DataFrame files.

    Keys keep their first-occurrence order from the RGS file, are stripped,
    de-duplicated, and exclude placeholder values (``""``, ``"nan"``,
    ``"None"``).

    Returns
    -------
    tuple
        ``(keys, key_column)``; ``([], None)`` when either file is not
        DataFrame-backed or the frames share no usable key column.
    """
    both_tabular = all(_is_dataframe_backed_spectrum_path(p) for p in (rgs_file, bgs_file))
    if not both_tabular:
        return [], None

    rgs_df = SpecPandasRow._read_dataframe_file(rgs_file, file_format="parquet")
    bgs_df = SpecPandasRow._read_dataframe_file(bgs_file, file_format="parquet")
    key_column = _choose_dual_pair_key_column(rgs_df, bgs_df)
    if key_column is None:
        return [], None

    placeholders = ("", "nan", "None")
    bgs_keys = set()
    for raw in bgs_df[key_column]:
        if raw is None:
            continue
        token = str(raw).strip()
        if token not in placeholders:
            bgs_keys.add(token)

    candidates = (str(raw).strip() for raw in rgs_df[key_column])
    # dict.fromkeys de-duplicates while preserving first-seen order.
    ordered = list(dict.fromkeys(
        token for token in candidates if token not in placeholders and token in bgs_keys
    ))
    return ordered, key_column
class ImageCutoutWidget(QWidget):
    """Widget for displaying astronomical image cutouts.

    It also carries the spectrum/image QA check boxes, whose states are
    encoded as an integer bitmask built from the ``QA_*_BIT`` constants.
    """

    # Bit set in the QA flag for "contamination from nearby source(s)".
    QA_CONTAMINATION_BIT = 1
    # Bit set in the QA flag for "unusable spectrum due to dominating artifacts".
    QA_UNUSABLE_BIT = 2

    def __init__(self, buffer_dir=None):
        """Build the widget.

        Parameters
        ----------
        buffer_dir : str or path-like, optional
            Directory used as the on-disk cutout cache.  It is created,
            including any missing parent directories, when it does not exist.
            ``None`` (or another falsy value) disables caching.
        """
        super().__init__()
        self.setFixedWidth(300)
        # Re-entrancy guard for load_online_cutouts.
        self._fetch_in_progress = False
        self.buffer_dir = Path(buffer_dir) if buffer_dir else None
        if self.buffer_dir:
            # parents=True: a nested cache path must not fail just because an
            # intermediate directory is missing.
            self.buffer_dir.mkdir(parents=True, exist_ok=True)
        self.setup_ui()
def setup_ui(self):
    """Create the child widgets of the cutout panel.

    From top to bottom: a selectable coordinate read-out, the image area
    (``self.images_layout``, which receives images and status messages), the
    spec-image QA check boxes, the auto-fetch toggle, the cutout-size
    selector, and a hidden download progress bar.
    """
    layout = QVBoxLayout()

    # Coordinate read-out; selectable so it can be copied.
    self.coord_label = QLabel("RA, DEC = -, -")
    self.coord_label.setFont(QFont("Arial", 14))
    self.coord_label.setStyleSheet("color: blue; background-color: #f0f0f0; padding: 2px;")
    self.coord_label.setTextInteractionFlags(Qt.TextSelectableByMouse | Qt.TextSelectableByKeyboard)
    layout.addWidget(self.coord_label)

    # Image area, added directly to the panel (no scroll area).
    self.images_widget = QWidget()
    self.images_layout = QVBoxLayout()
    self.images_layout.setSpacing(5)  # compact spacing between images
    self.images_widget.setLayout(self.images_layout)
    layout.addWidget(self.images_widget)

    # QA check boxes + auto-fetch toggle.
    header_layout = QVBoxLayout()
    qa_group = QGroupBox("Spec-image QA:")
    self.qa_group = qa_group
    qa_group.setFont(QFont("Arial", 14, QFont.Bold))
    qa_layout = QVBoxLayout()
    self.qa_contamination_cb = QCheckBox("Contamination from nearby\nsource(s)")
    self.qa_unusable_cb = QCheckBox("Unusable spectrum due to\ndominating artifacts")
    self.qa_contamination_cb.setFont(QFont("Arial", 13))
    self.qa_unusable_cb.setFont(QFont("Arial", 13))
    qa_checkbox_style = "QCheckBox::indicator { width: 24px; height: 24px; }"
    self.qa_contamination_cb.setStyleSheet(qa_checkbox_style)
    self.qa_unusable_cb.setStyleSheet(qa_checkbox_style)
    qa_layout.addWidget(self.qa_contamination_cb)
    qa_layout.addWidget(self.qa_unusable_cb)
    qa_group.setLayout(qa_layout)
    header_layout.addWidget(qa_group)

    auto_fetch_row = QHBoxLayout()
    self.auto_fetch_cb = QCheckBox("Auto-fetch image cutouts")
    self.auto_fetch_cb.setChecked(True)
    self.auto_fetch_cb.setToolTip("Automatically fetch images when coordinates change")
    auto_fetch_row.addWidget(self.auto_fetch_cb)
    auto_fetch_row.addStretch()
    header_layout.addLayout(auto_fetch_row)

    header_widget = QWidget()
    header_widget.setLayout(header_layout)
    layout.addWidget(header_widget)

    # Cutout settings (currently only the cutout size in arcsec).
    controls_group = QGroupBox("Cutout Settings")
    controls_layout = QGridLayout()
    controls_layout.addWidget(QLabel("Size (arcsec):"), 0, 0)
    self.size_combo = QComboBox()
    self.size_combo.addItems(["5", "10", "15", "30", "60"])
    self.size_combo.setCurrentText("10")
    controls_layout.addWidget(self.size_combo, 0, 1)
    controls_group.setLayout(controls_layout)
    layout.addWidget(controls_group)

    # Download progress bar, shown only while cutouts are being fetched.
    self.progress = QProgressBar()
    self.progress.setVisible(False)
    layout.addWidget(self.progress)

    self.setLayout(layout)
class AIMSZReviewPanel(QWidget):
    """Read-only review context plus reviewer-editable fields for AIMS-z."""

    # Emitted with the QA bitmask when a QA check box changes.
    qa_flag_changed = Signal(int)
    # Emitted with the full plain-text notes when the notes editor changes.
    notes_changed = Signal(str)
    # Emitted with the reviewer name when the reviewer field changes.
    reviewer_changed = Signal(str)

    def __init__(self, default_reviewer=""):
        """Build the panel.

        *default_reviewer* is the reviewer name used when a review record
        does not provide one.
        """
        super().__init__()
        # While True, the _emit_* change handlers stay silent.
        self._updating = False
        # Context field name -> QLabel displaying its value.
        self._context_labels = {}
        self._default_reviewer = str(default_reviewer or "")
        self.setFixedWidth(320)
        self._setup_ui()

    def _setup_ui(self):
        """Create the context, QA and editor groups and connect their signals."""
        layout = QVBoxLayout()
        layout.addWidget(self._build_context_group())
        layout.addWidget(self._build_qa_group())
        layout.addWidget(self._build_editor_group())
        layout.addStretch()
        self.setLayout(layout)

        self.qa_contamination_cb.stateChanged.connect(self._emit_qa_flag)
        self.qa_unusable_cb.stateChanged.connect(self._emit_qa_flag)
        self.notes_edit.textChanged.connect(self._emit_notes)
        self.reviewer_edit.textChanged.connect(self._emit_reviewer)
        # The notes editor is built and wired but hidden by default.
        self.hide_notes_editor()

    def _build_context_group(self):
        """Return the group box of read-only ``caption: value`` rows."""
        group = QGroupBox("AIMS-z Review Context")
        grid = QGridLayout()
        # (attribute looked up on the spectrum/record, caption shown in the panel)
        rows = (
            ("object_id", "object_id"),
            ("targetid", "targetid"),
            ("review_priority_tier", "tier"),
            ("review_score", "score"),
            ("review_rank_within_tier", "rank"),
            ("review_slice_label", "slice"),
            ("z_ref", "z_ref"),
            ("z_ml_expect", "z_ml"),
            ("z_pcf_best", "z_pcf"),
            ("pcf_template_best", "pcf_template"),
            ("pcf_score_best", "pcf_score"),
        )
        selectable = Qt.TextSelectableByMouse | Qt.TextSelectableByKeyboard
        for row, (field, caption) in enumerate(rows):
            grid.addWidget(QLabel(f"{caption}:"), row, 0)
            value_label = QLabel("-")
            value_label.setWordWrap(True)
            value_label.setTextInteractionFlags(selectable)
            grid.addWidget(value_label, row, 1)
            self._context_labels[field] = value_label
        group.setLayout(grid)
        return group

    def _build_qa_group(self):
        """Return the group box holding the two QA check boxes."""
        group = QGroupBox("Review QA")
        box = QVBoxLayout()
        self.qa_contamination_cb = QCheckBox("Contamination from nearby\nsource(s)")
        self.qa_unusable_cb = QCheckBox("Unusable spectrum due to\ndominating artifacts")
        indicator_style = "QCheckBox::indicator { width: 24px; height: 24px; }"
        for checkbox in (self.qa_contamination_cb, self.qa_unusable_cb):
            checkbox.setStyleSheet(indicator_style)
            box.addWidget(checkbox)
        group.setLayout(box)
        return group

    def _build_editor_group(self):
        """Return the group box with notes, reviewer and reviewed-at fields."""
        group = QGroupBox("Reviewer Edits")
        grid = QGridLayout()

        self.notes_label = QLabel("Notes:")
        grid.addWidget(self.notes_label, 0, 0)
        self.notes_edit = QTextEdit()
        self.notes_edit.setMinimumHeight(90)
        self.notes_edit.setFocusPolicy(Qt.ClickFocus)
        grid.addWidget(self.notes_edit, 0, 1)

        self.reviewer_label = QLabel("Reviewer:")
        grid.addWidget(self.reviewer_label, 1, 0)
        self.reviewer_edit = QLineEdit()
        self.reviewer_edit.setFocusPolicy(Qt.ClickFocus)
        grid.addWidget(self.reviewer_edit, 1, 1)

        self.reviewed_at_title = QLabel("Reviewed at:")
        grid.addWidget(self.reviewed_at_title, 2, 0)
        self.reviewed_at_label = QLabel("-")
        self.reviewed_at_label.setTextInteractionFlags(Qt.TextSelectableByMouse | Qt.TextSelectableByKeyboard)
        grid.addWidget(self.reviewed_at_label, 2, 1)

        group.setLayout(grid)
        return group
def hide_notes_editor(self):
    """Hide the notes caption and the notes text editor."""
    for widget in (self.notes_label, self.notes_edit):
        widget.setVisible(False)
def set_default_reviewer(self, reviewer):
    """Set the fallback reviewer name; falsy values (e.g. ``None``) become ``""``."""
    self._default_reviewer = str(reviewer) if reviewer else ""
@staticmethod
def _qa_flag_from_bits(contamination, unusable):
    """Combine the two QA booleans into the ImageCutoutWidget QA bitmask."""
    contamination_bit = ImageCutoutWidget.QA_CONTAMINATION_BIT if contamination else 0
    unusable_bit = ImageCutoutWidget.QA_UNUSABLE_BIT if unusable else 0
    return int(contamination_bit | unusable_bit)


def _emit_qa_flag(self):
    """Emit ``qa_flag_changed`` with the current bitmask unless the panel is updating."""
    if not self._updating:
        flag = self._qa_flag_from_bits(
            self.qa_contamination_cb.isChecked(),
            self.qa_unusable_cb.isChecked(),
        )
        self.qa_flag_changed.emit(flag)


def _emit_notes(self):
    """Emit ``notes_changed`` with the editor's plain text unless updating."""
    if not self._updating:
        self.notes_changed.emit(self.notes_edit.toPlainText())


def _emit_reviewer(self, text):
    """Forward the reviewer field's new *text* via ``reviewer_changed`` unless updating."""
    if not self._updating:
        self.reviewer_changed.emit(text)
def set_review_context(self, spec, record=None):
    """Populate the panel from a spectrum object and its review record.

    Each context label shows ``getattr(spec, field)``, falling back to
    ``record[field]`` when the attribute is ``None``.  Missing values
    (``None``, blank strings, NaN/NA) are shown as ``"-"``.

    The QA check boxes, notes, reviewer and review timestamp come from
    *record*.  A missing or non-numeric ``qa_flag`` (e.g. NaN read back from
    a table) counts as 0.  A missing reviewer falls back to the default
    reviewer.  The ``_emit_*`` change handlers are suppressed while the
    widgets are updated.

    Parameters
    ----------
    spec : object
        Spectrum-like object providing the context attributes.
    record : dict, optional
        Saved review record; anything that is not a ``dict`` is ignored.
    """

    def _is_na(value):
        # pd.isna returns an array for array-likes; such values count as present.
        try:
            return bool(pd.isna(value))
        except (TypeError, ValueError):
            return False

    def _is_missing(value):
        if value is None:
            return True
        if isinstance(value, str):
            return value.strip() == ""
        return _is_na(value)

    def _as_text(value):
        # Like ``str(value or "")``, but NaN/NA also map to "" instead of "nan".
        return "" if (_is_na(value) or not value) else str(value)

    def _as_flag(value):
        if _is_na(value):
            return 0
        try:
            return int(value or 0)
        except (TypeError, ValueError, OverflowError):
            return 0

    rec = record if isinstance(record, dict) else {}

    self._updating = True
    try:
        for field, label in self._context_labels.items():
            value = getattr(spec, field, None)
            if value is None:
                value = rec.get(field)
            label.setText("-" if _is_missing(value) else str(value))

        qa_flag = _as_flag(rec.get("qa_flag", 0))
        self.qa_contamination_cb.setChecked((qa_flag & ImageCutoutWidget.QA_CONTAMINATION_BIT) != 0)
        self.qa_unusable_cb.setChecked((qa_flag & ImageCutoutWidget.QA_UNUSABLE_BIT) != 0)

        reviewer = _as_text(rec.get("reviewer", "")) or str(self._default_reviewer)
        self.notes_edit.setPlainText(_as_text(rec.get("notes", "")))
        self.reviewer_edit.setText(reviewer)
        reviewed_at = _as_text(rec.get("reviewed_at", ""))
        self.reviewed_at_label.setText(reviewed_at if reviewed_at else "-")
    finally:
        self._updating = False
def get_qa_flag(self):
    """Return the QA bitmask encoded from the two QA check boxes."""
    pairs = (
        (self.QA_CONTAMINATION_BIT, self.qa_contamination_cb),
        (self.QA_UNUSABLE_BIT, self.qa_unusable_cb),
    )
    flag = 0
    for bit, box in pairs:
        if box.isChecked():
            flag |= bit
    return int(flag)
def set_qa_flag(self, qa_flag):
    """Set both QA check boxes from an integer bitmask without emitting signals.

    Values that cannot be converted with ``int()`` (``None``, NaN, text)
    are treated as 0.  Each check box's previous signal-blocking state is
    restored afterwards, even if updating a check box raises, so a caller
    that had already blocked signals keeps them blocked.
    """
    try:
        flag = int(qa_flag)
    except (TypeError, ValueError, OverflowError):
        flag = 0
    boxes = (self.qa_contamination_cb, self.qa_unusable_cb)
    # blockSignals() returns the previous blocked state; remember it to restore.
    previously_blocked = [box.blockSignals(True) for box in boxes]
    try:
        self.qa_contamination_cb.setChecked((flag & self.QA_CONTAMINATION_BIT) != 0)
        self.qa_unusable_cb.setChecked((flag & self.QA_UNUSABLE_BIT) != 0)
    finally:
        for box, was_blocked in zip(boxes, previously_blocked):
            box.blockSignals(was_blocked)
# Removed load_local_cutouts method since we removed the Load Local button
def load_online_cutouts(self, ra, dec, objid=None):
    """Fetch and display the cutouts for a sky position.

    Does nothing when the auto-fetch box is unchecked.  ``ra``/``dec`` are
    coerced to float.  Missing, non-numeric or non-finite coordinates reset
    the coordinate read-out and post a status message instead of fetching.

    Each survey in ``EUCLID_CUTOUT_SURVEYS`` is handled in turn on the
    calling thread.  It is taken from the on-disk cache when ``objid`` is
    given and a cached copy exists.  Otherwise it is downloaded with
    ``fetch_cutout`` and, when ``objid`` is given, written to the cache.
    Per-survey failures are reported as status messages and do not stop
    the remaining surveys.

    Parameters
    ----------
    ra, dec : float-like
        Target coordinates in degrees (as formatted in the read-out).
    objid : optional
        Object identifier used as the cache key; falsy disables caching.
    """
    if not self.auto_fetch_cb.isChecked():
        return

    try:
        ra = float(ra)
        dec = float(dec)
    except (TypeError, ValueError):
        ra = dec = float("nan")
    if not (np.isfinite(ra) and np.isfinite(dec)):
        self.coord_label.setText("RA, DEC = -, -")
        self.add_status_message("No valid coordinates available")
        return

    self.coord_label.setText(f"RA, DEC = {ra:.6f}, {dec:.6f}")

    # Re-entrancy guard: ignore requests while a fetch is still running.
    if self._fetch_in_progress:
        return
    self._fetch_in_progress = True
    self.clear_images()
    self.progress.setVisible(True)
    self.progress.setRange(0, 0)  # min == max == 0 -> indeterminate (busy) bar

    try:
        # Inside the try so a failure here cannot leave _fetch_in_progress stuck.
        size_arcsec = float(self.size_combo.currentText())
        for survey, band_name in list(EUCLID_CUTOUT_SURVEYS):
            try:
                cached_data = self.load_cutout_from_cache(objid, survey) if objid else None
                if cached_data is not None:
                    self.add_status_message(f"Loading {band_name} from cache...")
                    self.add_image_from_rgb_array(cached_data, band_name)
                    continue

                self.add_status_message(f"Fetching {band_name}...")
                result = fetch_cutout(
                    ra=ra,
                    dec=dec,
                    survey_name=survey,
                    size_arcsec=size_arcsec,
                    width=150,
                    height=150,
                )
                if result is None:
                    self.add_status_message(f"No data returned for {band_name}")
                    continue
                if objid:
                    self.save_cutout_to_cache(objid, survey, result)
                # The result is displayed as an RGB array (expected shape (H, W, 3)).
                self.add_image_from_rgb_array(result, band_name)
            except Exception as e:
                self.add_status_message(f"Error fetching {band_name}: {str(e)}")
    except Exception as e:
        self.add_status_message(f"General error in image fetching: {str(e)}")
    finally:
        self.progress.setVisible(False)
        self._fetch_in_progress = False
def add_image_from_pil(self, pil_image, label=""):
    """Display a PIL image (e.g. a decoded JPEG cutout) via an in-memory PNG."""
    try:
        with io.BytesIO() as png_buffer:
            pil_image.save(png_buffer, format='PNG')
            encoded = png_buffer.getvalue()
        pixmap = QPixmap()
        pixmap.loadFromData(encoded)
        self.add_image_widget(pixmap, label)
    except Exception as e:
        self.add_status_message(f"Error processing image {label}: {str(e)}")
def add_image_from_array(self, data, label=""):
    """Render a single-band numpy image and add it to the cutout panel.

    NaN/inf are replaced with finite values.  An all-zero image only posts a
    status message.  Otherwise an asinh stretch is applied, with a gain of
    500 for labels containing ``"VIS"`` or ``"$I_"`` and 1 otherwise.  The
    result is then linearly mapped to 0-255 between its 1st and 99th
    percentiles.  Those percentiles come from the positive pixels, or from
    all pixels when none are positive.  Errors are reported as status
    messages.
    """
    try:
        # Work in float so the VIS gain cannot overflow integer input arrays.
        data_scaled = np.nan_to_num(np.asarray(data, dtype=float))
        if not np.any(data_scaled):
            self.add_status_message(f"Empty image data for {label}")
            return

        gain = 500 if ("VIS" in label or "$I_" in label) else 1
        data_scaled = np.arcsinh(data_scaled * gain)

        positive = data_scaled[data_scaled > 0]
        # Fall back to all pixels so an all-negative image does not make
        # np.percentile fail on an empty selection.
        sample = positive if positive.size else data_scaled
        vmin, vmax = np.percentile(sample, [1, 99])
        if vmax > vmin:
            data_scaled = np.clip((data_scaled - vmin) / (vmax - vmin) * 255, 0, 255)
        else:
            data_scaled = np.zeros_like(data_scaled)

        image = Image.fromarray(data_scaled.astype(np.uint8), mode='L')
        with io.BytesIO() as buffer:
            image.save(buffer, format='PNG')
            pixmap = QPixmap()
            pixmap.loadFromData(buffer.getvalue())
        self.add_image_widget(pixmap, label)
    except Exception as e:
        self.add_status_message(f"Error processing image {label}: {str(e)}")
def add_image_from_rgb_array(self, rgb_data, label=""):
    """Display an ``(height, width, 3)`` RGB array (e.g. a decoded JPEG cutout).

    Arrays of any other shape are rejected with a status message.
    """
    try:
        shape = rgb_data.shape
        is_rgb = len(shape) == 3 and shape[2] == 3
        if not is_rgb:
            self.add_status_message(f"Expected RGB array for {label}, got shape {rgb_data.shape}")
            return
        rgb_image = Image.fromarray(rgb_data, mode='RGB')
        with io.BytesIO() as png_buffer:
            rgb_image.save(png_buffer, format='PNG')
            encoded = png_buffer.getvalue()
        pixmap = QPixmap()
        pixmap.loadFromData(encoded)
        self.add_image_widget(pixmap, label)
    except Exception as e:
        self.add_status_message(f"Error processing RGB image {label}: {str(e)}")
def create_composite_image(self, ra, dec, size_arcsec):
    """Placeholder for an RGB composite built from several bands.

    Not implemented: it only posts a status message.  ``ra``, ``dec`` and
    ``size_arcsec`` are accepted for interface compatibility and ignored.
    """
    notice = "Composite image creation not yet implemented"
    try:
        self.add_status_message(notice)
    except Exception as err:
        self.add_status_message(f"Error creating composite: {str(err)}")
def add_image_from_fits(self, fits_path, label=""):
    """Display the primary-HDU image of a FITS file.

    A linear stretch between the 1st and 99th percentiles is applied.  A
    flat image (equal percentiles) is drawn black instead of dividing by
    zero.  An empty primary HDU is reported as a status message, as are any
    errors.
    """
    try:
        with fits.open(fits_path) as hdul:
            data = hdul[0].data
            if data is None:
                self.add_status_message(f"No image data in primary HDU of {fits_path}")
                return
            # nan_to_num returns a copy, so nothing references the file after closing.
            data_scaled = np.nan_to_num(np.asarray(data, dtype=float))

        vmin, vmax = np.percentile(data_scaled, [1, 99])
        if vmax > vmin:
            data_scaled = np.clip((data_scaled - vmin) / (vmax - vmin) * 255, 0, 255)
        else:
            data_scaled = np.zeros_like(data_scaled)

        image = Image.fromarray(data_scaled.astype(np.uint8), mode='L')
        with io.BytesIO() as buffer:
            image.save(buffer, format='PNG')
            pixmap = QPixmap()
            pixmap.loadFromData(buffer.getvalue())
        self.add_image_widget(pixmap, label)
    except Exception as e:
        self.add_status_message(f"Error loading FITS {fits_path}: {str(e)}")
def add_image_from_file(self, file_path):
    """Add an image from disk, using the file name (without suffixes) as label.

    Files ending in ``.fits``, ``.fit`` or ``.fts`` (case-insensitive,
    optionally followed by ``.gz``) go through add_image_from_fits.  Any
    other file is loaded with ``QPixmap``.  An unreadable image is reported
    as a status message.

    Parameters
    ----------
    file_path : str or path-like
        Image file to load.
    """
    path = Path(file_path)
    suffixes = [suffix.lower() for suffix in path.suffixes]
    if suffixes and suffixes[-1] == ".gz":
        # Compressed file: classify and label by the suffix before ".gz".
        suffixes = suffixes[:-1]
        stem = Path(path.stem).stem
    else:
        stem = path.stem

    if suffixes and suffixes[-1] in (".fits", ".fit", ".fts"):
        self.add_image_from_fits(file_path, stem)
        return

    # QPixmap expects a string file name, so convert path-like input.
    pixmap = QPixmap(str(path))
    if pixmap.isNull():
        self.add_status_message(f"Could not load image: {file_path}")
    else:
        self.add_image_widget(pixmap, stem)
def add_image_widget(self, pixmap, label=""):
    """Append a captioned image to the image layout.

    The pixmap is scaled to fit 250x250 px (aspect ratio kept, smooth
    filtering) and drawn with a grey border.  The caption is omitted when
    *label* is empty.
    """
    container = QWidget()
    box = QVBoxLayout()
    box.setSpacing(3)  # tight gap between caption and image
    box.setContentsMargins(2, 2, 2, 2)

    if label:
        caption = QLabel(label)
        caption.setAlignment(Qt.AlignCenter)
        caption.setFont(QFont("Arial", 16))
        caption.setStyleSheet("margin-bottom: 2px;")
        box.addWidget(caption)

    thumbnail = pixmap.scaled(250, 250, Qt.KeepAspectRatio, Qt.SmoothTransformation)
    picture = QLabel()
    picture.setPixmap(thumbnail)
    picture.setAlignment(Qt.AlignCenter)
    picture.setStyleSheet("border: 1px solid gray;")
    box.addWidget(picture)

    container.setLayout(box)
    self.images_layout.addWidget(container)
def add_status_message(self, message):
    """Show *message* as green italic text in the image area for about 3 s."""
    import weakref

    note = QLabel(message)
    note.setWordWrap(True)
    note.setStyleSheet("color: green; font-style: italic;")
    self.images_layout.addWidget(note)

    # Keep only a weak reference so the pending timer does not itself keep
    # the Python label object alive.
    note_ref = weakref.ref(note)

    def _expire():
        target = note_ref()
        if target:
            self.remove_widget(target)

    QTimer.singleShot(3000, _expire)
def remove_widget(self, widget):
    """Detach *widget* from the image layout and schedule it for deletion.

    Falsy widgets, objects without ``parent`` and parentless widgets are
    ignored.  ``RuntimeError`` (raised when the widget was already deleted)
    is swallowed.
    """
    try:
        if not widget or not hasattr(widget, 'parent') or not widget.parent():
            return
        self.images_layout.removeWidget(widget)
        widget.deleteLater()
    except RuntimeError:
        # Widget already deleted; nothing left to remove.
        pass
def clear_images(self):
    """Empty the image layout, scheduling every item's widget for deletion."""
    layout = self.images_layout
    while layout.count() > 0:
        item = layout.takeAt(0)
        widget = item.widget()
        if widget:
            widget.deleteLater()
def get_cache_filename(self, objid, survey_name):
    """Return the cache file path for *objid*/*survey_name* under ``self.buffer_dir``.

    Delegates to the ``get_cache_filename`` helper imported from
    ``auxmodule.cutout_download``.
    """
    cache_root = self.buffer_dir
    return get_cache_filename(cache_root, objid, survey_name)
def save_cutout_to_cache(self, objid, survey_name, image_data):
    """Store *image_data* in the cutout cache; no-op without a cache directory.

    Errors are printed together with the target cache file instead of raised.
    """
    if not self.buffer_dir:
        return
    try:
        save_cutout_to_cache(self.buffer_dir, objid, survey_name, image_data)
    except Exception as exc:
        target = self.get_cache_filename(objid, survey_name)
        print(f"Error saving cache file {target}: {exc}")
def load_cutout_from_cache(self, objid, survey_name):
    """Return the cached cutout for *objid*/*survey_name*, or ``None``.

    Returns ``None`` when there is no cache directory or when loading fails;
    failures are printed together with the cache file path.
    """
    if self.buffer_dir:
        try:
            return load_cutout_from_cache(self.buffer_dir, objid, survey_name)
        except Exception as exc:
            target = self.get_cache_filename(objid, survey_name)
            print(f"Error loading cache file {target}: {exc}")
    return None
def prefetch_cutouts_background(self, spectrum_list, current_index, num_prefetch=None, size_arcsec=10):
    """Download cutouts for upcoming spectra into the cache on a daemon thread.

    Parameters
    ----------
    spectrum_list : sequence
        Spectrum objects.  Entries lacking ``ra``, ``dec`` or ``objid``
        attributes, or failing ``is_valid_cutout_target``, are skipped.
    current_index : int
        Index of the spectrum currently shown; prefetching starts after it.
    num_prefetch : int, optional
        Number of following spectra to prefetch; ``None`` means all remaining.
    size_arcsec : float, optional
        Cutout size passed to ``fetch_cutout`` (default 10, as before).

    Returns
    -------
    threading.Thread or None
        The started worker thread, or ``None`` when nothing is prefetched
        (no cache directory, or no spectra after ``current_index``).
    """
    if not self.buffer_dir or current_index >= len(spectrum_list) - 1:
        return None

    def prefetch_worker():
        if num_prefetch is None:
            end_index = len(spectrum_list)
        else:
            end_index = min(current_index + num_prefetch + 1, len(spectrum_list))
        surveys = list(EUCLID_CUTOUT_SURVEYS)

        for i in range(current_index + 1, end_index):
            try:
                spec_data = spectrum_list[i]
                if not all(hasattr(spec_data, attr) for attr in ("ra", "dec", "objid")):
                    continue
                ra, dec, objid = spec_data.ra, spec_data.dec, spec_data.objid
                # Target validity does not depend on the survey, so check it once.
                if not is_valid_cutout_target(objid, ra, dec):
                    continue
                for survey, _ in surveys:
                    if self.load_cutout_from_cache(objid, survey) is not None:
                        continue  # already cached
                    try:
                        result = fetch_cutout(
                            ra=ra,
                            dec=dec,
                            survey_name=survey,
                            size_arcsec=size_arcsec,
                            width=150,
                            height=150,
                        )
                        if result is not None:
                            self.save_cutout_to_cache(objid, survey, result)
                        # Short pause between requests (presumably to avoid
                        # hammering the cutout service).
                        time.sleep(0.1)
                    except Exception as e:
                        if not is_no_data_error(e):
                            print(f"Error prefetching cutout for object {objid}: {e}")
                        time.sleep(0.1)
            except Exception as e:
                print(f"Error processing spectrum {i} for prefetch: {e}")

    worker = threading.Thread(target=prefetch_worker, daemon=True)
    worker.start()
    return worker
class TemplateManager:
    """Registry of comparison templates keyed by display name.

    Each template is a dict with ``'wave'``, ``'flux'`` and ``'description'``
    keys.  Optional packaged templates that are missing or fail to load are
    stored as ``None``; get_available_templates skips them.
    """

    def __init__(self):
        self.templates = {}
        # Name of the template selected by default.
        self.current_template = "Type 1"
        self.load_default_templates()

    def load_default_templates(self):
        """Register the built-in Type 1 template and the optional packaged ones.

        Templates are registered in the order Type 1, ragn_dr1, Type 2,
        type2_euclid.  Problems loading an optional template are printed,
        not raised.
        """
        self.templates["Type 1"] = {
            'wave': tb_temp['Wave'].data,
            'flux': tb_temp['Flux'].data,
            'description': 'Type 1 AGN/QSO Template',
        }
        optional_templates = (
            ("ragn_dr1", ragn_dr1_template_file, 'ragn_dr1 AGN/QSO Template'),
            ("Type 2", type2_template_file, 'Type 2 AGN/QSO Template'),
            ("type2_euclid", type2_euclid_template_file, 'Type 2 Euclid AGN/QSO Template (ragn_na)'),
        )
        for name, path, description in optional_templates:
            self.templates[name] = self._load_optional_template(name, path, description)

    @staticmethod
    def _load_optional_template(name, path, description):
        """Load one packaged template file; return its entry, or ``None`` on any problem."""
        try:
            if not path.exists():
                print(f"{name} template file not found: {path}")
                return None
            loaded = load_template_table(path)
            return {
                'wave': loaded['wave'],
                'flux': loaded['flux'],
                'description': description,
            }
        except Exception as e:
            print(f"Failed to load {name} template from {path}: {e}")
            return None

    def get_template(self, template_name):
        """Return the template dict for *template_name*.

        Returns ``None`` for unknown names and for templates that failed to load.
        """
        return self.templates.get(template_name)

    def get_available_templates(self):
        """Return the names of all successfully loaded templates, in registration order."""
        return [name for name, template in self.templates.items() if template is not None]

    def add_template(self, name, wave, flux, description=""):
        """Register (or replace) a template under *name*."""
        self.templates[name] = {
            'wave': wave,
            'flux': flux,
            'description': description,
        }
[docs] class PGSpecPlotEnhanced(pg.PlotWidget): """Enhanced version of PGSpecPlot with image cutouts and template switching.""" coordinate_changed = Signal(float, float) # Signal for coordinate updates current_spec_changed = Signal() record_committed = Signal(object) def __init__(self, spectra, SpecClass=SpecEuclid1d, initial_counter=0, z_max=5.0, history_dict=None, euclid_fits=None, rgs_file=None, bgs_file=None, ext=None, extname=None, dual_good_pixels_only=False, external_redshift_lookup=None, external_redshift_key="object_id"): super().__init__() self.SpecClass = SpecClass self._is_aimsz_review = issubclass(SpecClass, SpecAIMSZReview) self.template_manager = TemplateManager() self.euclid_fits = euclid_fits self.external_redshift_lookup = dict(external_redshift_lookup or {}) self.external_redshift_key = str(external_redshift_key or "object_id") self._euclid_overlay_cache = {} self.dual_rgs_file = rgs_file self.dual_bgs_file = bgs_file self.dual_ext = ext self.dual_extname = extname self.dual_good_pixels_only = bool(dual_good_pixels_only) self._dual_mode = bool(self.dual_rgs_file and self.dual_bgs_file) self._dual_parquet_mode = self._dual_mode and _is_dataframe_backed_spectrum_path(self.dual_rgs_file) and _is_dataframe_backed_spectrum_path(self.dual_bgs_file) self._dual_pair_keys = [] self._dual_pair_key_column = None self.spec_dual = None self._legend = None self._observed_wmin = None self._observed_wmax = None self._annotation_wave = None self._annotation_flux = None self._view_lock_active = False self._locked_view_range = None self._suspend_view_lock_updates = False self.downsampling_enabled = False self._downsampling_auto = False self._downsampling_method = "mean" self._downsampling_clip_to_view = True # ``spectra`` can either be a FITS file containing multiple extensions # or a list of individual spectrum files. 
if self._dual_mode: self.speclist = None self.specfile = self.dual_rgs_file if self._dual_parquet_mode: self._dual_pair_keys, self._dual_pair_key_column = _ordered_shared_dual_pair_keys( self.dual_rgs_file, self.dual_bgs_file, ) if self.dual_ext is not None or self.dual_extname not in (None, ""): self.len_list = 1 else: self.len_list = len(self._dual_pair_keys) elif self.dual_ext is not None or self.dual_extname not in (None, ""): self.len_list = 1 else: rgs_count = self._count_hdus_in_file(self.dual_rgs_file) bgs_count = self._count_hdus_in_file(self.dual_bgs_file) if rgs_count != bgs_count: print( f"Warning: RGS/BGS extension count mismatch ({rgs_count} vs {bgs_count}); " f"using min={min(rgs_count, bgs_count)}." ) self.len_list = min(rgs_count, bgs_count) elif isinstance(spectra, (list, tuple)): self.speclist = list(spectra) self.specfile = None self.len_list = len(self.speclist) else: self.specfile = spectra if hasattr(SpecClass, "count_in_file"): self.len_list = int(SpecClass.count_in_file(spectra)) else: with fits.open(spectra) as hdul: self.len_list = len(hdul) - 1 self.speclist = None if initial_counter >= self.len_list: print("No more spectra to plot.\n\t Plotting the first spectrum.") initial_counter = 0 self.history = history_dict if history_dict is not None else {} self.setBackground('w') self.showGrid(x=True, y=True) self.setMouseEnabled(x=True, y=True) self.setLogMode(x=False, y=False) self.setAspectLocked(False) left_axis = self.getAxis('left') left_axis.enableAutoSIPrefix(False) self.enableAutoRange() self.vb = self.getViewBox() self.vb.setMouseMode(self.vb.RectMode) self.vb.sigRangeChangedManually.connect(self._on_manual_range_changed) self.z_min = 0.0 self.z_max = z_max self.base_z_step = 0.001 self.slider_min = 0 self.slider_max = int((1/self.base_z_step) * np.log((1+self.z_max)/(1+self.z_min))) self.message = '' self.counter = initial_counter self._last_committed_objid = None # Create slider and spin box for redshift control self.slider = 
QSlider(Qt.Horizontal) self.slider.setMinimum(self.slider_min) self.slider.setMaximum(self.slider_max) self.slider.setTickPosition(QSlider.TicksBelow) self.slider.setTickInterval(max(1, int((self.slider_max - self.slider_min) / 10))) self.slider.valueChanged.connect(self.slider_changed) # Create spectrum info title box self.spectrum_info_label = QLabel() self.spectrum_info_label.setFont(QFont("Arial", 14)) self.spectrum_info_label.setStyleSheet(""" QLabel { background-color: #f0f0f0; border: 2px solid #666666; padding: 8px; color: black; border-radius: 5px; } """) self.spectrum_info_label.setAlignment(Qt.AlignCenter) self.spectrum_info_label.setText("Loading spectrum info...") self.redshiftSpin = QDoubleSpinBox() self.redshiftSpin.setMinimum(self.z_min) self.redshiftSpin.setMaximum(self.z_max) self.redshiftSpin.setDecimals(4) self.redshiftSpin.setSingleStep(0.001) self.redshiftSpin.valueChanged.connect(self.spin_changed) self.slider.setStyleSheet(""" QSlider { background-color: white; } QSlider::groove:horizontal { border: 1px solid #999999; height: 8px; background: #b0c4de; margin: 2px 0; } QSlider::handle:horizontal { background: #6495ed; border: 1px solid #5c5c5c; width: 18px; margin: -2px 0; border-radius: 3px; } """) self.plot_next() @staticmethod def _is_sparcl_like_spec(spec): return ('SpecSparcl' in globals() and isinstance(spec, SpecSparcl)) or ( 'SpecAIMSZReview' in globals() and isinstance(spec, SpecAIMSZReview) ) def _is_current_sparcl_like(self): spec = getattr(self, "spec", None) return self._is_sparcl_like_spec(spec) def _default_class_token(self): return "QSO_DEFAULT" def _reviewer_default(self): reviewer = getattr(self, "_session_reviewer", None) if reviewer in (None, ""): reviewer = default_reviewer_username() return str(reviewer or "") @staticmethod def _is_positive_finite(value): try: value = float(value) except Exception: return False return np.isfinite(value) and value > 0 def _lookup_external_redshift_value(self, spec): if not 
self.external_redshift_lookup: return None candidate_attrs = [self.external_redshift_key] if self.external_redshift_key == "object_id": candidate_attrs.extend(["object_id", "objid", "extname"]) elif self.external_redshift_key == "objid": candidate_attrs.extend(["objid", "object_id", "extname"]) elif self.external_redshift_key == "extname": candidate_attrs.extend(["extname", "objid", "object_id"]) seen = set() for attr in candidate_attrs: if attr in seen: continue seen.add(attr) key = normalize_redshift_lookup_key(getattr(spec, attr, None)) if key is None: continue if key in self.external_redshift_lookup: return self.external_redshift_lookup[key] return None def _apply_external_redshift_to_spec(self, spec): z_ref = self._lookup_external_redshift_value(spec) if z_ref is None: return None spec.z_ref = z_ref return z_ref def _apply_external_redshift_overlay(self, spec, dual=None): applied = None if dual is not None: for arm in (getattr(dual, "rgs", None), getattr(dual, "bgs", None)): if arm is None: continue z_ref = self._apply_external_redshift_to_spec(arm) if applied is None and z_ref is not None: applied = z_ref if applied is None: applied = self._apply_external_redshift_to_spec(spec) elif getattr(spec, "z_ref", None) in (None, "", 0, 0.0): spec.z_ref = applied return applied def _prime_initial_redshift(self, spec): if self._is_positive_finite(getattr(spec, "z_vi", None)): return for attr in ("z_ref", "z_temp", "redshift"): value = getattr(spec, attr, None) if self._is_positive_finite(value): spec.z_vi = float(value) return def _new_history_record(self, spec, class_vi="", z_vi=None): return { "objname": getattr(spec, "objname", "Unknown"), "ra": getattr(spec, "ra", np.nan), "dec": getattr(spec, "dec", np.nan), "class_vi": normalize_class_label(class_vi), "z_vi": getattr(spec, "z_vi", 0.0) if z_vi is None else z_vi, "targetid": getattr(spec, "targetid", None), "data_release": normalize_data_release(getattr(spec, "data_release", None), 
aimsz_review=self._is_aimsz_review), "qa_flag": int(getattr(spec, "qa_flag", 0) or 0), "notes": str(getattr(spec, "notes", "") or ""), "reviewer": str(getattr(spec, "reviewer", "") or self._reviewer_default()), "reviewed_at": str(getattr(spec, "reviewed_at", "") or ""), "object_id": getattr(spec, "object_id", None), } def _coerce_history_record(self, spec, record): if isinstance(record, dict): coerced = dict(record) else: values = list(record) coerced = { "objname": values[0] if len(values) > 0 else getattr(spec, "objname", "Unknown"), "ra": values[1] if len(values) > 1 else getattr(spec, "ra", np.nan), "dec": values[2] if len(values) > 2 else getattr(spec, "dec", np.nan), "class_vi": values[3] if len(values) > 3 else "", "z_vi": values[4] if len(values) > 4 else getattr(spec, "z_vi", 0.0), "targetid": values[5] if len(values) > 5 else getattr(spec, "targetid", None), "data_release": values[6] if len(values) > 6 else getattr(spec, "data_release", None), "qa_flag": values[7] if len(values) > 7 else getattr(spec, "qa_flag", 0), "notes": values[8] if len(values) > 8 else "", "reviewer": values[9] if len(values) > 9 else "", "reviewed_at": values[10] if len(values) > 10 else "", "object_id": getattr(spec, "object_id", None), } coerced["class_vi"] = normalize_class_label(coerced.get("class_vi", "")) coerced["qa_flag"] = int(coerced.get("qa_flag", 0) or 0) coerced["notes"] = str(coerced.get("notes", "") or "") coerced["reviewer"] = str(coerced.get("reviewer", "") or self._reviewer_default()) coerced["reviewed_at"] = str(coerced.get("reviewed_at", "") or "") if coerced.get("targetid", None) is None: coerced["targetid"] = getattr(spec, "targetid", None) if coerced.get("data_release", None) is None: coerced["data_release"] = getattr(spec, "data_release", None) coerced["data_release"] = normalize_data_release(coerced.get("data_release", None), aimsz_review=self._is_aimsz_review) if coerced.get("objname", None) in (None, "", "nan"): coerced["objname"] = getattr(spec, 
"objname", "Unknown") if coerced.get("ra", None) is None: coerced["ra"] = getattr(spec, "ra", np.nan) if coerced.get("dec", None) is None: coerced["dec"] = getattr(spec, "dec", np.nan) if coerced.get("object_id", None) is None: coerced["object_id"] = getattr(spec, "object_id", None) return coerced def _history_record(self, spec=None, objid=None, create=False): if objid is None and spec is not None: objid = getattr(spec, "objid", None) if objid is None: return None record = self.history.get(objid) if record is None: if not create or spec is None: return None record = self._new_history_record(spec) self.history[objid] = record return record if isinstance(record, dict): return record if spec is not None: record = self._coerce_history_record(spec, record) self.history[objid] = record return record def _apply_history_to_spec(self, spec): record = self._history_record(spec=spec, create=False) if record is None: self._prime_initial_redshift(spec) return None spec.z_vi = record.get("z_vi", getattr(spec, "z_vi", 0.0)) spec.qa_flag = int(record.get("qa_flag", getattr(spec, "qa_flag", 0)) or 0) spec.class_vi = record.get("class_vi", getattr(spec, "class_vi", "")) if self._is_aimsz_review: spec.notes = record.get("notes", "") spec.reviewer = record.get("reviewer", self._reviewer_default()) spec.reviewed_at = record.get("reviewed_at", "") return record def _set_classification(self, spec, class_vi, z_vi=None): record = self._history_record(spec=spec, create=True) record["class_vi"] = normalize_class_label(class_vi) record["z_vi"] = getattr(spec, "z_vi", 0.0) if z_vi is None else z_vi record["qa_flag"] = int(getattr(spec, "qa_flag", 0) or 0) if self._is_aimsz_review: record["reviewer"] = str(getattr(spec, "reviewer", record.get("reviewer", "")) or self._reviewer_default()) spec.reviewer = record["reviewer"] record["targetid"] = record.get("targetid", getattr(spec, "targetid", None)) record["data_release"] = normalize_data_release( record.get("data_release", getattr(spec, 
"data_release", None)), aimsz_review=self._is_aimsz_review, ) record["objname"] = record.get("objname", getattr(spec, "objname", "Unknown")) record["ra"] = getattr(spec, "ra", record.get("ra", np.nan)) record["dec"] = getattr(spec, "dec", record.get("dec", np.nan)) if self._is_aimsz_review: spec.class_vi = record["class_vi"] self.history[spec.objid] = record return record def _history_rows_for_csv(self): rows = [] for objid_key, raw_record in self.history.items(): record = dict(raw_record) if isinstance(raw_record, dict) else self._coerce_history_record(self.spec, raw_record) if self._is_aimsz_review: rows.append( { "objid": objid_key, "targetid": record.get("targetid"), "ra": record.get("ra"), "dec": record.get("dec"), "data_release": record.get("data_release"), "class_vi": normalize_class_label(record.get("class_vi", "")), "z_vi": record.get("z_vi"), "qa_flag": int(record.get("qa_flag", 0) or 0), "notes": record.get("notes", ""), "reviewer": record.get("reviewer", "") or self._reviewer_default(), "reviewed_at": record.get("reviewed_at", ""), } ) else: rows.append( { "objid": objid_key, "objname": record.get("objname", "Unknown"), "ra": record.get("ra"), "dec": record.get("dec"), "class_vi": normalize_class_label(record.get("class_vi", "")), "z_vi": record.get("z_vi"), "targetid": record.get("targetid"), "data_release": record.get("data_release"), "qa_flag": int(record.get("qa_flag", 0) or 0), } ) return rows # Copy all existing methods from original PGSpecPlot @staticmethod def _count_hdus_in_file(filename): try: if hasattr(SpecEuclid1d, "count_in_file"): return int(SpecEuclid1d.count_in_file(filename)) except Exception: pass with fits.open(filename) as hdul: return max(0, len(hdul) - 1) def _set_legend(self, enabled): if self._legend is not None: try: self._legend.scene().removeItem(self._legend) except Exception: pass self._legend = None if enabled: self._legend = self.addLegend(offset=(10, 10)) def _load_spec(self, index_zero_based): """Load a spectrum by 
0-based index from ``spectra``. Args: index_zero_based: 0-based spectrum index (0 = first spectrum) For FITS files: automatically converted to 1-based extension number """ if self.speclist is not None: # For list of files: use 0-based index directly filename = self.speclist[index_zero_based] spec = self.SpecClass(filename) else: # For multi-extension FITS: convert 0-based index to 1-based extension number spec = self.SpecClass(self.specfile, ext=index_zero_based + 1) self._ensure_spec_defaults(spec) return spec def _resolve_dual_selector(self, index_zero_based): if not self._dual_parquet_mode: ext = self.dual_ext if self.dual_ext is not None else index_zero_based + 1 extname = self.dual_extname if extname in ("",): extname = None return ext, extname if self.dual_extname not in (None, ""): return None, str(self.dual_extname) if self.dual_ext is not None: pair_index = int(self.dual_ext) - 1 else: pair_index = int(index_zero_based) if pair_index < 0 or pair_index >= len(self._dual_pair_keys): raise IndexError( f"Dual-arm parquet pair index {pair_index} out of range for {len(self._dual_pair_keys)} shared pairs" ) return None, self._dual_pair_keys[pair_index] def _load_dual_spec(self, index_zero_based): ext, extname = self._resolve_dual_selector(index_zero_based) dual = SpecEuclid1dDual( rgs_file=self.dual_rgs_file, bgs_file=self.dual_bgs_file, ext=ext, extname=extname, good_pixels_only=self.dual_good_pixels_only, ) spec = dual.rgs if dual.rgs is not None else dual.bgs if spec is None: raise RuntimeError("SpecEuclid1dDual returned no arm data.") self._ensure_spec_defaults(spec) self._apply_external_redshift_overlay(spec, dual=dual) return spec, dual def _load_current_spec(self, index_zero_based): if self._dual_mode: spec, dual = self._load_dual_spec(index_zero_based) self.spec_dual = dual return spec self.spec_dual = None spec = self._load_spec(index_zero_based) self._apply_external_redshift_overlay(spec) return spec def _ensure_spec_defaults(self, spec): """Ensure 
common attributes exist on ``spec``.""" if not hasattr(spec, 'z_vi'): spec.z_vi = getattr(spec, 'redshift', 0.0) if not hasattr(spec, 'z_ph'): spec.z_ph = getattr(spec, 'redshift', 0.0) if not hasattr(spec, 'z_gaia'): spec.z_gaia = None if not hasattr(spec, 'objid'): spec.objid = self.counter if not hasattr(spec, 'objname'): spec.objname = 'Unknown' if not hasattr(spec, 'qa_flag'): spec.qa_flag = 0 if self._is_aimsz_review: if not hasattr(spec, 'class_vi'): spec.class_vi = "" if not hasattr(spec, 'notes'): spec.notes = "" if not hasattr(spec, 'reviewer'): spec.reviewer = self._reviewer_default() elif not getattr(spec, 'reviewer', ""): spec.reviewer = self._reviewer_default() if not hasattr(spec, 'reviewed_at'): spec.reviewed_at = ""
def update_slider_and_spin(self):
    """Synchronise the redshift slider and spin box with ``self.spec.z_vi``.

    The slider position uses the logarithmic mapping
    ``int((1 / base_z_step) * log((1 + z) / (1 + z_min)))``, the same one used
    by ``spin_changed``. Widget signals are blocked while the values are set,
    so no re-plot is triggered.

    A missing, non-numeric, non-finite, or non-positive ``z_vi`` falls back to
    ``self.z_min`` instead of raising (e.g. ``None`` would raise ``TypeError``
    and ``inf`` would overflow ``int``).
    """
    raw_z = getattr(self.spec, "z_vi", None)
    try:
        z_candidate = float(raw_z)
    except (TypeError, ValueError):
        # e.g. no redshift known yet (None) or an unparsable value.
        z_candidate = float("nan")
    if np.isfinite(z_candidate) and z_candidate > 0:
        initial_z = z_candidate
    else:
        initial_z = self.z_min
    initial_slider_value = int((1 / self.base_z_step) * np.log((1 + initial_z) / (1 + self.z_min)))

    self.slider.blockSignals(True)
    self.slider.setValue(initial_slider_value)
    self.slider.blockSignals(False)

    self.redshiftSpin.blockSignals(True)
    self.redshiftSpin.setValue(initial_z)
    self.redshiftSpin.blockSignals(False)
def _on_manual_range_changed(self, *_args):
    """Slot for manual range changes: remember the user's view unless suspended."""
    if not self._suspend_view_lock_updates:
        self._capture_current_view_range()


def _capture_current_view_range(self):
    """Store the view box's current x/y range as the locked view.

    Does nothing if the range cannot be read, is not a pair of 2-element
    ranges, or contains non-finite values.
    """
    try:
        xr, yr = self.vb.viewRange()
    except Exception:
        return
    if xr is None or yr is None or len(xr) != 2 or len(yr) != 2:
        return
    corners = [xr[0], xr[1], yr[0], yr[1]]
    if not all(np.isfinite(corners)):
        return
    x_lo, x_hi, y_lo, y_hi = (float(c) for c in corners)
    self._locked_view_range = ((x_lo, x_hi), (y_lo, y_hi))
    self._view_lock_active = True


def _clear_view_lock(self):
    """Forget any locked view so the next redraw resets the axes."""
    self._locked_view_range = None
    self._view_lock_active = False


def _restore_locked_view(self):
    """Re-apply the stored locked view range.

    Returns True when the range was applied. Returns False when there is no
    usable locked range or when applying it raised. Once the ``try`` is
    entered, ``_suspend_view_lock_updates`` is reset to False on every exit
    path.
    """
    if not self._view_lock_active or self._locked_view_range is None:
        return False
    try:
        x_pair, y_pair = self._locked_view_range
        bounds = [x_pair[0], x_pair[1], y_pair[0], y_pair[1]]
        if not all(np.isfinite(bounds)):
            return False
        self._suspend_view_lock_updates = True
        for axis_name in ('x', 'y'):
            self.enableAutoRange(axis=axis_name, enable=False)
        self.vb.setRange(
            xRange=(float(x_pair[0]), float(x_pair[1])),
            yRange=(float(y_pair[0]), float(y_pair[1])),
            padding=0.0,
            disableAutoRange=True,
        )
        return True
    except Exception:
        return False
    finally:
        self._suspend_view_lock_updates = False
def slider_changed(self, slider_value):
    """Qt slot: turn a slider position into a redshift and redraw.

    Inverse of the logarithmic slider mapping:
    ``z = exp(base_z_step * slider_value) * (1 + z_min) - 1``. The spin box is
    updated with its signals blocked so ``spin_changed`` does not fire, then
    the plot is rebuilt, keeping a locked (user-zoomed) view if one is active.
    """
    redshift = np.exp(self.base_z_step * slider_value) * (1 + self.z_min) - 1
    self.spec.z_vi = redshift

    spin = self.redshiftSpin
    spin.blockSignals(True)
    spin.setValue(redshift)
    spin.blockSignals(False)

    self.clear()
    self.plot_single(preserve_view=self._view_lock_active)
def spin_changed(self, z_value):
    """Qt slot: store the spin-box redshift, move the slider to match, and redraw.

    The slider position is ``int(log((1 + z) / (1 + z_min)) / base_z_step)``.
    It is set with signals blocked so ``slider_changed`` does not fire.
    """
    self.spec.z_vi = z_value
    position = int((1 / self.base_z_step) * np.log((1 + z_value) / (1 + self.z_min)))

    slider = self.slider
    slider.blockSignals(True)
    slider.setValue(position)
    slider.blockSignals(False)

    self.clear()
    self.plot_single(preserve_view=self._view_lock_active)
def plot_single(self, preserve_view=False):
    """Plot the current spectrum (and the selected template) on the canvas.

    Three data layouts are handled:

    * dual-arm Euclid (``self.spec_dual`` set): RGS and scaled BGS arms, each
      with its ``good_mask`` pixels highlighted;
    * single Euclid spectrum (``spec.telescope == 'euclid'``): non-finite
      pixels dropped, ``good_mask`` pixels highlighted;
    * everything else: non-finite and >10-sigma flux values removed with
      ``sigma_clip``. SPARCL-like spectra are additionally cut to
      3800-9800 Å and may get a Euclid overlay scaled to the spectrum's
      median absolute flux.

    Side effects: sets ``self.wave``/``self.flux`` (the data used for
    template scaling and line labels), ``self._observed_wmin``/
    ``self._observed_wmax`` (x-range and template clip span, ``None`` if no
    finite data), ``self._annotation_wave``/``self._annotation_flux`` (points
    eligible for click annotation), then updates the axis labels, the info
    label and the axis ranges.

    Args:
        preserve_view: if True and a locked (user-zoomed) view exists, restore
            it instead of resetting the axes.
    """
    spec = self.spec
    dual = self.spec_dual
    is_sparcl = self._is_sparcl_like_spec(spec)
    # ``telescope`` may be missing or None on some spectrum classes.
    is_euclid = str(getattr(spec, 'telescope', '') or '').lower() == 'euclid'
    is_euclid_dual = dual is not None
    self._observed_wmin = None
    self._observed_wmax = None
    self._annotation_wave = None
    self._annotation_flux = None
    self.flux_sm = None
    self._set_legend(is_euclid_dual)
    annotation_wave_parts = []
    annotation_flux_parts = []

    def _finite_span(*arrays):
        """Return ``(min, max)`` over the finite values of ``arrays``, or ``(None, None)``."""
        chunks = [np.asarray(a, dtype=float).ravel() for a in arrays if a is not None]
        if not chunks:
            return None, None
        values = np.concatenate(chunks)
        values = values[np.isfinite(values)]
        if values.size == 0:
            return None, None
        return float(values.min()), float(values.max())

    # Plain arrays with astropy units stripped.
    wave_full = spec.wave.value if hasattr(spec.wave, 'value') else spec.wave
    flux_full = spec.flux.value if hasattr(spec.flux, 'value') else spec.flux

    if is_euclid_dual:
        dual_data = dual.for_redshift()
        rgs = dual_data.get("rgs", {})
        bgs = dual_data.get("bgs", {})
        wave_rgs_raw = rgs.get("wavelength")
        flux_rgs_raw = rgs.get("flux")
        wave_bgs_raw = bgs.get("wavelength")
        flux_bgs_raw = bgs.get("flux")
        wave_rgs = np.asarray(wave_rgs_raw if wave_rgs_raw is not None else [], dtype=float)
        flux_rgs = np.asarray(flux_rgs_raw if flux_rgs_raw is not None else [], dtype=float)
        wave_bgs = np.asarray(wave_bgs_raw if wave_bgs_raw is not None else [], dtype=float)
        flux_bgs = np.asarray(flux_bgs_raw if flux_bgs_raw is not None else [], dtype=float)
        # The x-range covers every finite wavelength, even where the flux is not finite.
        wave_rgs_for_xlim = wave_rgs[np.isfinite(wave_rgs)]
        wave_bgs_for_xlim = wave_bgs[np.isfinite(wave_bgs)]
        finite_r = np.isfinite(wave_rgs) & np.isfinite(flux_rgs)
        finite_b = np.isfinite(wave_bgs) & np.isfinite(flux_bgs)
        wave_rgs = wave_rgs[finite_r]
        flux_rgs = flux_rgs[finite_r]
        wave_bgs = wave_bgs[finite_b]
        flux_bgs = flux_bgs[finite_b]

        if wave_rgs.size > 0:
            self.plot(
                wave_rgs,
                flux_rgs,
                pen=pg.mkPen((44, 160, 44, 220), width=2),
                name="RGS",
                antialias=True,
            )
            gm_r = getattr(dual.rgs, "good_mask", None) if dual.rgs is not None else None
            if gm_r is not None and dual.rgs is not None and len(gm_r) == len(getattr(dual.rgs, "wave", [])):
                rw = np.asarray(dual.rgs.wave.value if hasattr(dual.rgs.wave, "value") else dual.rgs.wave, dtype=float)
                rf = np.asarray(dual.rgs.flux.value if hasattr(dual.rgs.flux, "value") else dual.rgs.flux, dtype=float)
                gm_r = np.asarray(gm_r, dtype=bool) & np.isfinite(rw) & np.isfinite(rf)
                if np.any(gm_r):
                    self.plot(rw[gm_r], rf[gm_r], pen=pg.mkPen((44, 160, 44, 240), width=2), antialias=True)
        if wave_bgs.size > 0:
            self.plot(
                wave_bgs,
                flux_bgs,
                pen=pg.mkPen((31, 119, 180, 220), width=2),
                name="BGS (scaled)",
                antialias=True,
            )
            gm_b = getattr(dual.bgs, "good_mask", None) if dual.bgs is not None else None
            if gm_b is not None and dual.bgs is not None and len(gm_b) == len(getattr(dual.bgs, "wave", [])):
                bw = np.asarray(dual.bgs.wave.value if hasattr(dual.bgs.wave, "value") else dual.bgs.wave, dtype=float)
                bf = np.asarray(dual.bgs.flux.value if hasattr(dual.bgs.flux, "value") else dual.bgs.flux, dtype=float)
                # Good-pixel overlay uses the raw arm flux, so apply the BGS->RGS scale here.
                bs = float(getattr(dual, "arm_scale_bgs_to_rgs", 1.0))
                gm_b = np.asarray(gm_b, dtype=bool) & np.isfinite(bw) & np.isfinite(bf)
                if np.any(gm_b):
                    self.plot(bw[gm_b], bf[gm_b] * bs, pen=pg.mkPen((31, 119, 180, 240), width=2), antialias=True)

        parts_wave = []
        parts_flux = []
        if wave_rgs.size > 0:
            parts_wave.append(wave_rgs)
            parts_flux.append(flux_rgs)
        if wave_bgs.size > 0:
            parts_wave.append(wave_bgs)
            parts_flux.append(flux_bgs)
        # Never reuse the previous spectrum's arrays when neither arm has finite data.
        self.wave = np.array([], dtype=float)
        self.flux = np.array([], dtype=float)
        if parts_wave:
            all_wave = np.concatenate(parts_wave)
            all_flux = np.concatenate(parts_flux)
            self.wave = all_wave
            self.flux = all_flux
            self._annotation_wave = all_wave
            self._annotation_flux = all_flux
            self._observed_wmin, self._observed_wmax = _finite_span(wave_rgs_for_xlim, wave_bgs_for_xlim)
    elif is_sparcl:
        # SPARCL-like spectra are restricted to 3800-9800 Å before cleaning.
        idx = (wave_full >= 3800) & (wave_full <= 9800)
        wave_full = wave_full[idx]
        flux_full = flux_full[idx]

    if not is_euclid_dual:
        if is_euclid:
            finite = np.isfinite(wave_full) & np.isfinite(flux_full)
            wave = wave_full[finite]
            flux = flux_full[finite]
        else:
            # Drop non-finite and extreme (>10 sigma) flux values.
            flux_masked = np.ma.masked_invalid(flux_full)
            flux_sigclip = sigma_clip(flux_masked, sigma=10, maxiters=3)
            # getmaskarray always yields a full boolean array, even if nothing is masked.
            keep = ~np.ma.getmaskarray(flux_sigclip)
            wave = wave_full[keep]
            flux = flux_sigclip.data[keep]
        self.wave = wave
        self.flux = flux
        annotation_wave_parts = [np.asarray(wave)]
        annotation_flux_parts = [np.asarray(flux)]
        # Default x-range is the cleaned spectrum; a Euclid overlay below may widen it.
        self._observed_wmin, self._observed_wmax = _finite_span(wave)

        if is_sparcl:
            if self.downsampling_enabled:
                downsample_item = self.plot(wave, flux, pen=pg.mkPen('k', width=2), antialias=True)
                downsample_item.setDownsampling(ds=3, auto=self._downsampling_auto, method=self._downsampling_method)
                downsample_item.setClipToView(self._downsampling_clip_to_view)
            else:
                self.plot(wave, flux, pen=pg.mkPen('k', width=1.5), antialias=True)

            euclid_object_id = getattr(spec, "euclid_object_id", None)
            dr_text = str(getattr(spec, "data_release", "") or "")
            is_desi_overlay = "desi" in dr_text.lower()
            # Good Euclid pixels: purple over DESI spectra, green otherwise.
            overlay_good_pen = pg.mkPen((128, 0, 128, 220), width=2) if is_desi_overlay else pg.mkPen((0, 150, 0, 220), width=2)
            euclid_spec = None
            if self.euclid_fits is not None and euclid_object_id not in (None, "", 0):
                euclid_spec = self._load_euclid_overlay(euclid_object_id)
            if euclid_spec is not None:
                euclid_wave = euclid_spec.wave.value
                euclid_flux = euclid_spec.flux.value
                euclid_good_mask = getattr(euclid_spec, 'good_mask', None)
                has_good_mask = euclid_good_mask is not None and len(euclid_good_mask) == len(euclid_flux)
                # Match the overlay's median |flux| (good pixels if available) to the spectrum's.
                scale = 1.0
                denom_flux = euclid_flux
                if has_good_mask:
                    good = np.asarray(euclid_good_mask, dtype=bool) & np.isfinite(euclid_wave) & np.isfinite(euclid_flux)
                    if np.any(good):
                        denom_flux = euclid_flux[good]
                denom = np.nanmedian(np.abs(denom_flux))
                numer = np.nanmedian(np.abs(flux))
                if np.isfinite(denom) and denom > 0 and np.isfinite(numer) and numer > 0:
                    scale = numer / denom
                euclid_flux_scaled = euclid_flux * scale
                self.plot(
                    euclid_wave,
                    euclid_flux_scaled,
                    pen=pg.mkPen((95, 95, 95, 150), width=2),
                    antialias=True,
                )
                if has_good_mask:
                    good = np.asarray(euclid_good_mask, dtype=bool) & np.isfinite(euclid_wave) & np.isfinite(euclid_flux_scaled)
                    if np.any(good):
                        self.plot(
                            euclid_wave[good],
                            euclid_flux_scaled[good],
                            pen=overlay_good_pen,
                            antialias=True,
                        )
                annotation_wave_parts.append(np.asarray(euclid_wave))
                annotation_flux_parts.append(np.asarray(euclid_flux_scaled))
                combined_min, combined_max = _finite_span(wave, euclid_wave)
                if combined_min is not None:
                    self._observed_wmin, self._observed_wmax = combined_min, combined_max
        elif is_euclid:
            self.plot(wave, flux, pen=pg.mkPen((95, 95, 95, 150), width=2), antialias=True)
            good_mask = getattr(spec, 'good_mask', None)
            if good_mask is not None and len(good_mask) == len(wave_full):
                good_mask = np.asarray(good_mask, dtype=bool) & np.isfinite(wave_full) & np.isfinite(flux_full)
                if np.any(good_mask):
                    wave_good = wave_full[good_mask]
                    flux_good = flux_full[good_mask]
                    self.plot(wave_good, flux_good, pen=pg.mkPen((0, 0, 180, 220), width=2), antialias=True)
                    # Template scaling and line-label placement use the good pixels only.
                    self.wave = wave_good
                    self.flux = flux_good
        else:
            self.plot(wave, flux, pen='b', symbol='o', symbolSize=4, symbolPen=None,
                      connect='finite', symbolBrush='k', antialias=True)

    if annotation_wave_parts and annotation_flux_parts:
        self._annotation_wave = np.concatenate(annotation_wave_parts)
        self._annotation_flux = np.concatenate(annotation_flux_parts)

    # Axis labels; fall back to unit-less text when the spectrum carries no units.
    flux_unit_str = 'Flux'
    if getattr(spec, 'flux_unit', None) is not None:
        flux_unit_str = f'Flux ({spec.flux_unit})'
    wave_unit_str = 'Wavelength (Å)'
    if getattr(spec, 'wave_unit', None) is not None:
        wave_unit_str = f'Wavelength ({spec.wave_unit})'

    # Plot the template clipped to the observed wavelength span (skipped when
    # there is no finite observed data to clip against).
    template = self.template_manager.get_template(self.template_manager.current_template)
    if template is not None and self._observed_wmin is not None and self._observed_wmax is not None:
        z_vi = getattr(spec, 'z_vi', getattr(spec, 'redshift', 0.0))
        if z_vi is None:
            z_vi = 0.0
        wmin = float(self._observed_wmin)
        wmax = float(self._observed_wmax)
        wave_temp = template['wave'] * (1 + z_vi)
        flux_temp = template['flux']
        in_span = (wave_temp >= wmin) & (wave_temp <= wmax)
        wave_temp = wave_temp[in_span]
        flux_temp = flux_temp[in_span]
        if wave_temp.size > 0:
            temp_mean = np.mean(flux_temp)
            if np.isfinite(temp_mean) and temp_mean != 0:
                current_flux = getattr(self, "flux", None)
                ref_flux = current_flux if current_flux is not None and len(current_flux) > 0 else flux_temp
                flux_temp_scaled = flux_temp / temp_mean * np.abs(np.nanmean(ref_flux)) * 1.5
                self.plot(wave_temp, flux_temp_scaled, pen=pg.mkPen(_TEMPLATE_COLOR, width=2), antialias=True)
        self._label_template_emission_lines(
            wmin=wmin,
            wmax=wmax,
            z=z_vi,
        )

    self.setLabel('left', flux_unit_str)
    self.setLabel('bottom', wave_unit_str)

    # Update info label above plot
    self.update_spectrum_info_label()
    self._apply_axes(preserve_view=preserve_view, is_sparcl=is_sparcl)
# Coordinate signal emission is now handled in navigation methods
def _apply_axes(self, *, preserve_view=False, is_sparcl=False):
    """Set the plot's x/y ranges after a redraw.

    If ``preserve_view`` is set and a locked view can be restored, nothing
    else happens. Otherwise the x-axis is fixed to the observed wavelength
    span (when known), and the y-axis is chosen as follows:

    * every finite flux value is exactly zero -> fixed symmetric range
      (±1e-16 when the spectrum has a ``flux_unit``, otherwise ±1);
    * SPARCL-like spectra -> 0.01-99.99 percentile range with 5 % margins;
    * otherwise -> y auto-range.

    ``_suspend_view_lock_updates`` is held True while ranges are set.
    Presumably this stops range-change notifications from being captured as
    a user-locked view (``_on_manual_range_changed`` checks this flag).
    """
    if preserve_view and self._restore_locked_view():
        return

    def _zero_flux_limits():
        # Prefer the annotation arrays (they include any overlay).
        source = self._annotation_flux
        if source is None:
            source = getattr(self, 'flux', None)
        if source is None:
            return None
        values = np.asarray(source, dtype=float)
        values = values[np.isfinite(values)]
        if values.size == 0 or not np.allclose(values, 0.0, rtol=0.0, atol=0.0):
            return None
        half = 1e-16 if getattr(self.spec, 'flux_unit', None) is not None else 1.0
        return -half, half

    self._suspend_view_lock_updates = True
    try:
        if self._observed_wmin is not None and self._observed_wmax is not None:
            self.enableAutoRange(axis='x', enable=False)
            self.setXRange(float(self._observed_wmin), float(self._observed_wmax), padding=0.0)

        zero_limits = _zero_flux_limits()
        if zero_limits is not None:
            low, high = zero_limits
            self.enableAutoRange(axis='y', enable=False)
            self.setYRange(float(low), float(high), padding=0.0)
            return
        if not is_sparcl:
            self.enableAutoRange(axis='y', enable=True)
            return

        flux_ref = np.asarray(getattr(self, 'flux', []), dtype=float)
        finite_ref = flux_ref[np.isfinite(flux_ref)]
        if finite_ref.size == 0:
            self.enableAutoRange(axis='y', enable=True)
            return
        self.enableAutoRange(axis='y', enable=False)
        lo, hi = np.percentile(finite_ref, [0.01, 99.99])
        margin = 0.05 * (hi - lo)
        ymin = lo - margin
        ymax = hi + margin
        if ymin == ymax:
            # Flat spectrum: open up a small window around the constant value.
            pad = abs(ymin) * 0.1 if ymin != 0 else 1.0
            ymin -= pad
            ymax += pad
        self.setYRange(ymin, ymax, padding=0.05)
    finally:
        self._suspend_view_lock_updates = False
def update_spectrum_info_label(self):
    """Refresh the text of the info box shown above the plot.

    The first line lists ``Spectrum i/N``, the object ID, the target ID (for
    SPARCL-like spectra), ``z_vi``, the visual class, and whichever reference
    redshifts are available. When a dual-arm Euclid spectrum is shown, a
    second line reports the BGS→RGS arm scale and the overlap region. Missing
    or non-numeric dual-arm diagnostics are shown as ``nan``/``0``/``none``
    instead of raising. Does nothing if no spectrum has been loaded yet.
    """
    if not hasattr(self, 'spec'):
        return
    spec = self.spec
    z_vi = getattr(spec, 'z_vi', 0.0)
    z_gaia = getattr(spec, 'z_gaia', None)
    objid = getattr(spec, 'objid', 'Unknown')
    # Same SpecSparcl / SpecAIMSZReview test the class uses elsewhere.
    is_sparcl = self._is_sparcl_like_spec(spec)

    # Prefer the class stored in the review history over the spectrum attribute.
    class_vi = None
    record = self._history_record(spec=spec, create=False)
    if record is not None:
        class_vi = record.get("class_vi")
    if class_vi in (None, ""):
        class_vi = getattr(spec, "class_vi", None)

    def _fmt_z(label, value, *, hide_zero=True):
        """Return ``'label = value'`` (4 decimals), or None for missing/non-finite values."""
        if value is None:
            return None
        try:
            v = float(value)
        except Exception:
            return None
        if not np.isfinite(v):
            return None
        if hide_zero and v == 0.0:
            return None
        return f"{label} = {v:.4f}"

    def _as_float(value, default=np.nan):
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _as_int(value, default=0):
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    # Navigation methods set ``_displaying_spectrum_number`` just before plotting;
    # otherwise fall back to the raw counter.
    if hasattr(self, '_displaying_spectrum_number'):
        current_spectrum_number = self._displaying_spectrum_number
    else:
        current_spectrum_number = self.counter if hasattr(self, 'len_list') else 1
    message = f"Spectrum {current_spectrum_number}/{self.len_list}" if hasattr(self, 'len_list') else ""

    parts = []
    if message:
        parts.append(message)
    display_objid = getattr(spec, "object_id", objid) if self._is_aimsz_review else objid
    parts.append(f"ID: {display_objid}")
    if is_sparcl:
        targetid = getattr(spec, 'targetid', None)
        if targetid not in (None, "", 0):
            parts.append(f"targetid: {targetid}")
    parts.append(_fmt_z("z_vi", z_vi, hide_zero=False) or "z_vi = -")
    if class_vi not in (None, ""):
        parts.append(f"class_vi: {display_class_label(class_vi)}")
    if not self._is_aimsz_review:
        z_ref_str = _fmt_z("z_ref", getattr(spec, "z_ref", None), hide_zero=False)
        if z_ref_str is not None:
            parts.append(z_ref_str)
    if is_sparcl:
        dr = str(getattr(spec, 'data_release', '') or '')
        if 'desi' in dr.lower():
            parts.append(_fmt_z("z_desi", getattr(spec, 'redshift', None), hide_zero=False) or "z_desi = -")
    z_gaia_str = _fmt_z("z_gaia", z_gaia, hide_zero=True)
    if z_gaia_str is not None:
        parts.append(z_gaia_str)

    dual_diag = None
    dual = self.spec_dual
    if dual is not None:
        scale = _as_float(getattr(dual, "arm_scale_bgs_to_rgs", 1.0))
        status = str(getattr(dual, "scale_status", "unknown"))
        owmin = _as_float(getattr(dual, "overlap_wmin", np.nan))
        owmax = _as_float(getattr(dual, "overlap_wmax", np.nan))
        n_bgs = _as_int(getattr(dual, "overlap_n_bgs", 0))
        n_rgs = _as_int(getattr(dual, "overlap_n_rgs", 0))
        if np.isfinite(owmin) and np.isfinite(owmax) and owmax > owmin:
            dual_diag = (
                f"BGS→RGS scale={scale:.4g} ({status}), "
                f"overlap={owmin:.1f}-{owmax:.1f} Å, nB={n_bgs}, nR={n_rgs}"
            )
        else:
            dual_diag = f"BGS→RGS scale={scale:.4g} ({status}), overlap=none"

    text_content = " ".join(parts)
    if dual_diag:
        text_content = f"{text_content}\n{dual_diag}"
    if hasattr(self, 'spectrum_info_label'):
        self.spectrum_info_label.setText(text_content)
def plot_template(self):
    """Overlay the current template, redshifted by ``spec.z_vi``, on the existing plot.

    Unlike ``plot_single`` this neither clears the canvas nor changes the axis
    ranges. It adds the template curve and the emission-line markers, then
    refreshes the info label. The template is clipped to the observed
    wavelength span and scaled to 1.5x the mean absolute flux of the plotted
    spectrum, using the same NaN-safe scaling as ``plot_single``. Returns
    early (plotting nothing) when no template is selected, no finite
    wavelengths exist, or the clipped template is empty or has zero or
    non-finite mean.
    """
    template = self.template_manager.get_template(self.template_manager.current_template)
    if template is None:
        return
    z = getattr(self.spec, 'z_vi', getattr(self.spec, 'redshift', 0.0))
    if z is None:
        z = 0.0

    if self._observed_wmin is not None and self._observed_wmax is not None:
        wmin = float(self._observed_wmin)
        wmax = float(self._observed_wmax)
    else:
        wave_full = np.asarray(self.spec.wave.value if hasattr(self.spec.wave, 'value') else self.spec.wave, dtype=float)
        flux_full = np.asarray(self.spec.flux.value if hasattr(self.spec.flux, 'value') else self.spec.flux, dtype=float)
        usable = np.isfinite(wave_full) & np.isfinite(flux_full)
        if not np.any(usable):
            # No finite (wave, flux) pairs: fall back to the finite wavelengths alone.
            usable = np.isfinite(wave_full)
        if not np.any(usable):
            return
        wmin = float(np.min(wave_full[usable]))
        wmax = float(np.max(wave_full[usable]))

    wave_shifted = template['wave'] * (1 + z)
    flux_template = template['flux']
    in_span = (wave_shifted >= wmin) & (wave_shifted <= wmax)
    wave_shifted = wave_shifted[in_span]
    flux_template = flux_template[in_span]
    if wave_shifted.size == 0:
        return
    template_mean = np.mean(flux_template)
    if not np.isfinite(template_mean) or template_mean == 0:
        return

    # Scale against the cleaned flux from plot_single when available, else the raw spectrum.
    ref_flux = getattr(self, 'flux', None)
    if ref_flux is None or len(ref_flux) == 0:
        ref_flux = self.spec.flux.value if hasattr(self.spec.flux, 'value') else self.spec.flux
    flux_scaled = flux_template / template_mean * np.abs(np.nanmean(ref_flux)) * 1.5

    # Plot template clipped to the observed wavelength range.
    self.plot(wave_shifted, flux_scaled, pen=pg.mkPen(_TEMPLATE_COLOR, width=2), antialias=True)
    if self._observed_wmin is not None and self._observed_wmax is not None:
        self._label_template_emission_lines(wmin=self._observed_wmin, wmax=self._observed_wmax, z=z)
    elif getattr(self, 'wave', None) is not None and len(self.wave) > 0:
        self._label_template_emission_lines(wmin=float(np.nanmin(self.wave)), wmax=float(np.nanmax(self.wave)), z=z)
    # Update info label to reflect new redshift
    self.update_spectrum_info_label()
# Do NOT auto-range during template updates - preserves x-axis range
def _load_euclid_overlay(self, euclid_object_id):
    """Return the Euclid 1-D spectrum stored under ``euclid_object_id`` in ``self.euclid_fits``.

    The ID is normalised to an extension name: integers and whole-number
    floats become their integer string, and everything else is stringified
    and stripped. ``None``, float NaN, empty strings and ``"nan"`` yield
    ``None``. Results, including failed loads (stored as ``None``), are
    memoised in ``self._euclid_overlay_cache``.
    """
    if euclid_object_id is None:
        return None
    try:
        is_nan_float = isinstance(euclid_object_id, float) and np.isnan(euclid_object_id)
    except Exception:
        is_nan_float = False
    if is_nan_float:
        return None

    # Whole-number floats (e.g. 123.0 read from a float column) map to "123".
    if isinstance(euclid_object_id, (np.integer, int)) or (
        isinstance(euclid_object_id, (np.floating, float)) and float(euclid_object_id).is_integer()
    ):
        extname = str(int(euclid_object_id)).strip()
    else:
        extname = str(euclid_object_id).strip()
    if not extname or extname.lower() == "nan":
        return None

    cache = self._euclid_overlay_cache
    if extname not in cache:
        try:
            cache[extname] = SpecEuclid1d(self.euclid_fits, extname=extname)
        except Exception as exc:
            print(f"Euclid overlay load failed for extname={extname}: {exc}")
            cache[extname] = None
    return cache[extname]


def _label_template_emission_lines(self, *, wmin, wmax, z):
    """Overlay emission-line markers for the template within [wmin, wmax].

    Each line in ``_TEMPLATE_EMISSION_LINES`` whose redshifted position lies
    inside the window gets a dashed red vertical line plus a text label. Labels
    are placed just below 92 % of the spectrum's peak flux (or at 0 when it is
    unavailable) and staggered over three heights to reduce overlap. Nothing is
    drawn for an invalid window or a non-finite/negative redshift.
    """
    if not np.isfinite(wmin) or not np.isfinite(wmax) or wmax <= wmin:
        return
    try:
        z = float(z)
    except Exception:
        return
    if not np.isfinite(z) or z < 0:
        return

    # Use current (cleaned) spectrum flux to place labels near the top.
    peak = None
    flux_ref = getattr(self, 'flux', None)
    if flux_ref is not None and len(flux_ref) > 0:
        try:
            peak = float(np.nanmax(flux_ref))
        except Exception:
            peak = None
    if peak is None or not np.isfinite(peak):
        peak = 0.0
    y_base = peak * 0.92 if peak != 0 else 0.0

    pen = pg.mkPen((255, 0, 0), width=2, style=Qt.DashLine)
    text_color = (80, 80, 80)
    stagger = [0.0, 0.06, 0.12]
    placed = 0
    for name, rest_aa in _TEMPLATE_EMISSION_LINES:
        x_obs = rest_aa * (1.0 + z)
        if x_obs < wmin or x_obs > wmax:
            continue
        self.addItem(pg.InfiniteLine(pos=x_obs, angle=90, pen=pen, movable=False))
        if y_base != 0:
            y_pos = y_base * (1.0 - stagger[placed % len(stagger)])
        else:
            y_pos = 0.0
        text_item = pg.TextItem(text=str(name), color=text_color, anchor=(0.5, 1.0))
        text_item.setPos(x_obs, y_pos)
        self.addItem(text_item)
        placed += 1
def plot_next(self):
    """Advance to and plot the next loadable spectrum.

    ``self.counter`` is the 0-based index of the spectrum to load next. Spectra
    that fail to load are reported and skipped. On success the history is
    applied, the redshift widgets are synced, the spectrum is plotted, the
    coordinate/spec-changed signals are emitted, and the counter moves past the
    shown spectrum (so it then equals the 1-based display number).
    """
    if self.counter >= self.len_list:
        print("No more spectra to plot.")
        return
    self.clear()
    self._clear_view_lock()

    # self.counter must equal the index being loaded: spectrum defaults read it.
    while self.counter < self.len_list:
        try:
            loaded_spec = self._load_current_spec(self.counter)
        except Exception as e:
            print(f"Failed to load spectrum {self.counter + 1}: {e}")
            self.counter += 1
            continue
        self.spec = loaded_spec
        break
    else:
        print("No more spectra to plot.")
        return

    history_entry = self._apply_history_to_spec(loaded_spec)
    if history_entry is not None:
        print(f"\tVisual class from history: {history_entry.get('class_vi', '')}.")

    self.update_slider_and_spin()
    # 1-based number of the spectrum about to be shown.
    self._displaying_spectrum_number = self.counter + 1
    self.plot_single()

    # Emit coordinate change signal for cutout loading
    current = self.spec
    if hasattr(current, 'ra') and hasattr(current, 'dec'):
        self.coordinate_changed.emit(current.ra, current.dec)
    self.current_spec_changed.emit()
    self.counter += 1
def plot_previous(self):
    """Go back to and plot the spectrum before the one currently shown.

    ``self.counter`` holds the 1-based number of the displayed spectrum, so the
    previous one is 0-based index ``counter - 2``. Spectra that fail to load
    are reported and skipped, walking further back.

    While a candidate is loading, ``self.counter`` is set to its 0-based
    index, the same value ``plot_next`` uses. ``_ensure_spec_defaults`` fills
    a missing ``objid`` from ``self.counter``, so this keeps that fallback
    objid (and therefore the history key) independent of navigation
    direction. If no earlier spectrum can be loaded, the counter is restored.
    """
    if self.counter <= 1:
        print("No previous spectrum to plot.")
        return
    self.clear()
    self._clear_view_lock()

    original_counter = self.counter
    loaded = False
    spec = None
    target = self.counter - 2
    while target >= 0:
        self.counter = target
        try:
            spec = self._load_current_spec(target)
        except Exception as e:
            print(f"Failed to load spectrum {target + 1}: {e}")
            target -= 1
            continue
        loaded = True
        break
    if not loaded:
        self.counter = original_counter
        print("No previous spectrum to plot.")
        return

    self.spec = spec
    # Counter is the 1-based display number once the spectrum is shown.
    self.counter = target + 1
    record = self._apply_history_to_spec(spec)
    if record is not None:
        class_vi = record.get("class_vi", "")
        print(f"\tVisual class from history: {class_vi}.")
    self.update_slider_and_spin()
    self._displaying_spectrum_number = self.counter
    self.plot_single()

    # Emit coordinate change signal for cutout loading
    if hasattr(self.spec, 'ra') and hasattr(self.spec, 'dec'):
        self.coordinate_changed.emit(self.spec.ra, self.spec.dec)
    self.current_spec_changed.emit()
def jump_to_spectrum(self, index_one_based):
    """Jump directly to a spectrum by 1-based index.

    Invalid or out-of-range indices are reported and ignored. While the
    spectrum loads, ``self.counter`` is set to its 0-based index, matching
    ``plot_next``. ``_ensure_spec_defaults`` fills a missing ``objid`` from
    ``self.counter``, so this keeps the fallback objid (history key)
    consistent. If loading fails, the previous counter is restored. On
    success the counter becomes ``index_one_based``.
    """
    try:
        index_one_based = int(index_one_based)
    except Exception:
        print(f"Invalid index: {index_one_based}")
        return
    if index_one_based < 1 or index_one_based > self.len_list:
        print(f"Index out of range: {index_one_based}. Valid range is 1..{self.len_list}.")
        return
    self.clear()
    self._clear_view_lock()

    target_zero_based = index_one_based - 1
    previous_counter = self.counter
    self.counter = target_zero_based
    try:
        spec = self._load_current_spec(target_zero_based)
    except Exception as e:
        self.counter = previous_counter
        print(f"Failed to load spectrum {index_one_based}: {e}")
        return

    self.spec = spec
    self.counter = index_one_based
    record = self._apply_history_to_spec(spec)
    if record is not None:
        class_vi = record.get("class_vi", "")
        print(f"\tVisual class from history: {class_vi}.")
    self.update_slider_and_spin()
    self._displaying_spectrum_number = index_one_based
    self.plot_single()
    if hasattr(self.spec, 'ra') and hasattr(self.spec, 'dec'):
        self.coordinate_changed.emit(self.spec.ra, self.spec.dec)
    self.current_spec_changed.emit()
def change_template(self, template_name):
    """Switch the overlay template to ``template_name`` and redraw.

    Names not returned by ``template_manager.get_available_templates()`` are
    ignored. The redraw keeps a user-locked (manually zoomed) view, the same
    way the redshift slider and spin box handlers do, so changing templates
    does not throw away the user's zoom.
    """
    if template_name not in self.template_manager.get_available_templates():
        return
    self.template_manager.current_template = template_name
    self.clear()
    self.plot_single(preserve_view=getattr(self, "_view_lock_active", False))
def _annotate_at_wave(self, wave_pos): """Annotate the nearest plotted point to ``wave_pos``.""" wave_arr = self._annotation_wave if self._annotation_wave is not None else getattr(self, 'wave', None) flux_arr = self._annotation_flux if self._annotation_flux is not None else getattr(self, 'flux', None) if wave_arr is None or flux_arr is None: return finite = np.isfinite(wave_arr) & np.isfinite(flux_arr) if not np.any(finite): return wave_fin = wave_arr[finite] flux_fin = flux_arr[finite] idx = np.abs(wave_fin - wave_pos).argmin() wave_val = wave_fin[idx] flux_val = flux_fin[idx] annotation_text = pg.TextItem( text="Wavelength: {0:.2f} Flux: {1:.2e}".format(wave_val, flux_val), anchor=(0, 0), color='r', border='w', fill=(255, 255, 255, 200)) annotation_text.setFont(QFont("Arial", 18, QFont.Bold)) annotation_text.setPos(wave_val, flux_val) self.addItem(annotation_text) print("Wavelength: {0:.2f} Flux: {1:.2e}".format(wave_val, flux_val))
def keyPressEvent(self, event):
    """Handle keyboard shortcuts for classifying and navigating spectra.

    Plain keys:
        Q           commit the current spectrum (default class when it has no
                    history entry yet) and advance. When ``(counter - 1)`` is
                    a multiple of 50 afterwards, the history is also written
                    to ``vi_temp_<counter-1>.csv``.
        S G A N B F U D L
                    set the visual class to STAR, GALAXY, QSO, QSO_NARROW,
                    QSO_BAL, QSO_FELOBAL, UNKNOWN, BAD or LIKELY_Q. STAR,
                    UNKNOWN and BAD store ``z_vi = 0.0``; the others keep the
                    spectrum's current ``z_vi``.
        R           clear the view lock and re-plot.
        Left/Right  previous / next spectrum.
        M           print the mouse position in data coordinates.
        Space       annotate the data point nearest the mouse wavelength.

    Ctrl combinations (each also prints the mouse position):
        Ctrl+R      reload the current spectrum.
        Ctrl+Left   go to the first spectrum.
        Ctrl+Right  go to the last spectrum.
        Ctrl+B      resume from history: ``counter`` is set from
                    ``len(self.history)`` (clamped) before ``plot_next()``.

    Every commit emits ``record_committed(objid)`` and stores the objid in
    ``self._last_committed_objid``, which is reset to ``None`` on each press.
    """
    key = event.key()
    ctrl = bool(event.modifiers() & Qt.ControlModifier)
    spec = self.spec
    self._last_committed_objid = None

    def _commit_current(class_vi=None, z_vi=None):
        # Write class / redshift / QA flag for ``spec`` into history and announce it.
        if class_vi is not None:
            record = self._set_classification(spec, class_vi, z_vi)
        else:
            record = self._history_record(spec=spec, create=True)
            record["z_vi"] = getattr(spec, "z_vi", 0.0) if z_vi is None else z_vi
        record["qa_flag"] = int(getattr(spec, "qa_flag", 0) or 0)
        self._last_committed_objid = spec.objid
        self.record_committed.emit(spec.objid)
        return record

    # (key, class label, console description, keep the current z_vi?)
    class_shortcuts = (
        (Qt.Key_S, "STAR", "STAR", False),
        (Qt.Key_G, "GALAXY", "GALAXY", True),
        (Qt.Key_A, "QSO", "QSO", True),
        (Qt.Key_N, "QSO_NARROW", "QSO(Narrow)", True),
        (Qt.Key_B, "QSO_BAL", "QSO(BAL)", True),
        (Qt.Key_F, "QSO_FELOBAL", "QSO(FeLoBAL)", True),
        (Qt.Key_U, "UNKNOWN", "UNKNOWN", False),
        (Qt.Key_D, "BAD", "BAD spectrum", False),
        (Qt.Key_L, "LIKELY_Q", "LIKELY/Unusual QSO", True),
    )

    if key == Qt.Key_Q:
        if spec.objid not in self.history:
            _commit_current(self._default_class_token(), spec.z_vi)
        else:
            _commit_current(None, spec.z_vi)
        if self.counter < self.len_list:
            self.clear()
            self.plot_next()
        else:
            print("No more spectra to plot.")
        # Temp save every 50 spectra like original
        if (self.counter - 1) % 50 == 0:
            print("Saving temp file to csv (n={})...".format(self.counter))
            temp_filename = f"vi_temp_{self.counter-1}.csv"
            df_new = pd.DataFrame(self._history_rows_for_csv())
            df_new.to_csv(temp_filename, index=False)
    elif not (ctrl and key == Qt.Key_B):
        # Ctrl+B is the resume-from-history shortcut handled below; it must
        # not also tag the current spectrum as QSO_BAL.
        for shortcut, label, description, keep_z in class_shortcuts:
            if key == shortcut:
                print(f"\tClass: {description}.")
                _commit_current(label, spec.z_vi if keep_z else 0.0)
                self.update_spectrum_info_label()
                break

    if key == Qt.Key_R and not ctrl:
        self._clear_view_lock()
        self.clear()
        self.plot_single()

    if ctrl:
        # Mouse position like original
        mouse_pos = self.mapFromGlobal(QCursor.pos())
        vb = self.getViewBox()
        print(vb.mapSceneToView(mouse_pos))

    if ctrl:
        if key == Qt.Key_R:
            # Load first, so a failed read leaves the current plot in place.
            try:
                reloaded = self._load_current_spec(self.counter - 1)
            except Exception as e:
                print(f"Failed to reload spectrum {self.counter}: {e}")
                return
            self.clear()
            self._clear_view_lock()
            self.spec = reloaded
            # Re-apply saved history, as plot_previous/jump_to_spectrum do
            # after loading, so the spectrum reflects the recorded review.
            record = self._apply_history_to_spec(reloaded)
            if record is not None:
                class_vi = record.get("class_vi", "")
                print(f"\tVisual class from history: {class_vi}.")
            # For reload, display number is current counter
            self._displaying_spectrum_number = self.counter
            self.update_slider_and_spin()
            self.plot_single()
            self.current_spec_changed.emit()
        elif key == Qt.Key_Right:
            self.clear()
            self.counter = self.len_list - 1
            self.plot_next()
        elif key == Qt.Key_Left:
            self.clear()
            self.counter = 0
            self.plot_next()
        elif key == Qt.Key_B:
            self.clear()
            # Clamp so an empty or oversized history still yields a valid counter.
            self.counter = min(max(len(self.history), 1), self.len_list) - 1
            self.plot_next()
    elif key == Qt.Key_Left:
        self.plot_previous()
    elif key == Qt.Key_Right:
        self.plot_next()
    elif key == Qt.Key_M:
        mouse_pos = self.mapFromGlobal(QCursor.pos())
        self.vb = self.getViewBox()
        mouse_pos = self.vb.mapSceneToView(mouse_pos)
        print(f"Mouse position - Wavelength: {mouse_pos.x():.2f}, Flux: {mouse_pos.y():.2e}")
    elif key == Qt.Key_Space:
        # Single annotation per press (previously handled in two chains).
        mouse_pos = self.mapFromGlobal(QCursor.pos())
        self.vb = self.getViewBox()
        wave = self.vb.mapSceneToView(mouse_pos).x()
        self._annotate_at_wave(wave)
class PGSpecPlotAppEnhanced(QApplication):
    """Enhanced standalone application with image cutouts and template switching."""

    @staticmethod
    def _normalize_objid(value):
        """Normalize objid loaded from CSV.

        Keeps integers as int (for legacy FITS workflows) and keeps
        non-numeric IDs (e.g. SPARCL UUIDs) as str. ``None`` and NaN
        (including numpy float NaN) are returned unchanged.
        """
        if value is None:
            return value
        if isinstance(value, (np.floating, float)) and np.isnan(value):
            return value
        if isinstance(value, (np.integer, int)):
            return int(value)
        if isinstance(value, (np.floating, float)):
            if float(value).is_integer():
                return int(value)
            return str(value)
        s = str(value)
        try:
            return int(s)
        except Exception:
            return s

    @staticmethod
    def _normalize_qa_flag(value):
        """Normalize qa_flag loaded from CSV/history; missing or invalid -> 0."""
        if value is None or (isinstance(value, float) and np.isnan(value)):
            return 0
        try:
            return int(value)
        except Exception:
            return 0

    @staticmethod
    def _clean_scalar(value):
        """Return ``None`` for missing cells (None, NaN/NA, blank strings).

        Strings are returned stripped; other values are returned unchanged.
        """
        if value is None:
            return None
        try:
            if pd.isna(value):
                return None
        except Exception:
            pass
        if isinstance(value, str):
            text = value.strip()
            return None if text == "" else text
        return value

    @classmethod
    def _history_text(cls, value, default=""):
        """Render a CSV text cell as ``str``, mapping missing cells to ``default``.

        pandas reads empty cells as NaN, which is truthy, so a bare
        ``str(value or default)`` would store the literal string ``"nan"``.
        """
        if cls._clean_scalar(value) is None:
            value = None
        return str(value or default)

    @classmethod
    def _is_aimsz_review_class(cls, spec_class):
        """True when ``spec_class`` is a ``SpecAIMSZReview`` subclass."""
        try:
            return issubclass(spec_class, SpecAIMSZReview)
        except Exception:
            return False

    @classmethod
    def _is_sparcl_class(cls, spec_class):
        """True when ``spec_class`` is a ``SpecSparcl`` subclass."""
        try:
            return issubclass(spec_class, SpecSparcl)
        except Exception:
            return False

    @classmethod
    def _normalize_loaded_objid(cls, row, *, is_aimsz_review=False, is_sparcl=False,
                                aimsz_lookup=None):
        """Derive the history key for a CSV ``row``.

        AIMS-z review rows prefer ``object_id``, then ``objid``, then
        ``targetid``, canonicalised through ``SpecAIMSZReview``. SPARCL rows
        keep already-prefixed objids and otherwise canonicalise through
        ``SpecSparcl``. All other rows go through ``_normalize_objid``.
        ``aimsz_lookup`` is accepted for call compatibility but is not used.
        Returns ``None`` when no identifier is available.
        """
        if is_aimsz_review:
            object_id = cls._clean_scalar(row.get("object_id", None))
            if object_id is not None:
                return SpecAIMSZReview._canonical_objid(object_id)
            raw_objid = cls._clean_scalar(row.get("objid", None))
            if raw_objid is not None:
                raw_text = str(raw_objid).strip()
                if raw_text.startswith("aimsz:"):
                    return raw_text
                if raw_text.startswith("targetid"):
                    suffix = raw_text[len("targetid"):].strip()
                    if suffix:
                        return SpecAIMSZReview._canonical_objid(suffix)
                return SpecAIMSZReview._canonical_objid(raw_text)
            targetid = cls._clean_scalar(row.get("targetid", None))
            if targetid is not None:
                return SpecAIMSZReview._canonical_objid(str(targetid).strip())
            return None

        if is_sparcl:
            sparcl_id = cls._clean_scalar(row.get("sparcl_id", None))
            targetid = cls._clean_scalar(row.get("targetid", None))
            specid = cls._clean_scalar(row.get("specid", None))
            raw_objid = cls._clean_scalar(row.get("objid", None))
            if isinstance(raw_objid, str):
                raw_text = raw_objid.strip()
                if raw_text.startswith(("sparcl:", "targetid:", "specid:", "sparcl-row:")):
                    return raw_text
            return SpecSparcl._canonical_objid(
                sparcl_id=sparcl_id,
                targetid=targetid if targetid is not None else raw_objid,
                specid=specid if specid is not None else raw_objid,
                filename=None,
                row=0,
            )

        raw_objid = cls._clean_scalar(row.get("objid", None))
        return cls._normalize_objid(raw_objid)

    @classmethod
    def _row_to_history_record(cls, row, *, is_aimsz_review=False):
        """Convert one CSV row into a history record dict.

        The review text fields (notes, reviewer, reviewed_at) are only kept
        for AIMS-z review sessions; empty cells map to ``""`` (reviewer falls
        back to ``default_reviewer_username()``) instead of ``"nan"``.
        """
        if is_aimsz_review:
            notes = cls._history_text(row.get("notes", ""))
            reviewer = (cls._history_text(row.get("reviewer", ""))
                        or str(default_reviewer_username()))
            reviewed_at = cls._history_text(row.get("reviewed_at", ""))
        else:
            notes = reviewer = reviewed_at = ""
        return {
            "objname": row.get("objname", "Unknown"),
            "ra": row.get("ra", np.nan),
            "dec": row.get("dec", np.nan),
            "class_vi": normalize_class_label(row.get("class_vi", "")),
            "z_vi": row.get("z_vi", np.nan),
            "targetid": row.get("targetid", None),
            "data_release": normalize_data_release(row.get("data_release", None),
                                                   aimsz_review=is_aimsz_review),
            "qa_flag": cls._normalize_qa_flag(row.get("qa_flag", 0)),
            "notes": notes,
            "reviewer": reviewer,
            "reviewed_at": reviewed_at,
            "object_id": row.get("object_id", None),
        }

    @staticmethod
    def _now_reviewed_at():
        """Current UTC time as an ISO-8601 string with whole seconds."""
        return datetime.now(timezone.utc).replace(microsecond=0).isoformat()

    def __init__(self, spectra, SpecClass=SpecEuclid1d, output_file='vi_output.csv', z_max=5.0,
                 load_history=False, euclid_fits=None, cutout_buffer_dir=None,
                 enable_image_panel=None, enable_background_prefetch=None,
                 rgs_file=None, bgs_file=None, ext=None, extname=None,
                 dual_good_pixels_only=False, external_redshift_lookup=None,
                 external_redshift_key="object_id"):
        """Create the application, its plot widget, optional panels and window.

        Parameters
        ----------
        spectra : str, path-like or list
            A single FITS file or a list of spectrum files handed to the plot.
        SpecClass : type
            Spectrum class used to load entries.
        output_file : str
            CSV holding the review history; read when ``load_history`` is
            true and the file exists, and saved on ``aboutToQuit``.
        z_max : float
            Passed to the plot widget.
        enable_image_panel : bool or None
            Show the cutout panel (default ``False``).
        enable_background_prefetch : bool or None
            Prefetch cutouts in the background; defaults to
            ``enable_image_panel`` and is always off without the image panel.
        rgs_file, bgs_file : str or None
            Dual mode is enabled when both are given.
        cutout_buffer_dir : str or path-like, optional
            Cutout cache directory; see ``_resolve_cutout_buffer_dir``.

        The remaining arguments are forwarded to ``PGSpecPlotEnhanced``.
        """
        super().__init__(sys.argv)
        self.output_file = output_file
        self.spectra = spectra
        self.SpecClass = SpecClass
        self._is_aimsz_review = self._is_aimsz_review_class(self.SpecClass)
        self._is_sparcl = self._is_sparcl_class(self.SpecClass)
        self._session_reviewer = default_reviewer_username()
        self.euclid_fits = euclid_fits
        if enable_image_panel is None:
            enable_image_panel = False
        self.enable_image_panel = bool(enable_image_panel)
        if enable_background_prefetch is None:
            enable_background_prefetch = self.enable_image_panel
        self.enable_background_prefetch = bool(enable_background_prefetch) and self.enable_image_panel
        self.rgs_file = rgs_file
        self.bgs_file = bgs_file
        self.dual_ext = ext
        self.dual_extname = extname
        self.dual_good_pixels_only = bool(dual_good_pixels_only)
        self._dual_mode = bool(self.rgs_file and self.bgs_file)
        self.external_redshift_lookup = dict(external_redshift_lookup or {})
        self.external_redshift_key = str(external_redshift_key or "object_id")

        if load_history and os.path.exists(self.output_file):
            history_dict, initial_counter = self._load_history_csv()
        else:
            history_dict = {}
            initial_counter = 0

        self.plot = PGSpecPlotEnhanced(
            self.spectra, self.SpecClass, initial_counter=initial_counter, z_max=z_max,
            history_dict=history_dict, euclid_fits=self.euclid_fits,
            rgs_file=self.rgs_file, bgs_file=self.bgs_file, ext=self.dual_ext,
            extname=self.dual_extname, dual_good_pixels_only=self.dual_good_pixels_only,
            external_redshift_lookup=self.external_redshift_lookup,
            external_redshift_key=self.external_redshift_key)
        self.plot._session_reviewer = self._session_reviewer
        self.len_list = self.plot.len_list
        self.plot.current_spec_changed.connect(self.on_current_spec_changed)
        self.plot.record_committed.connect(self.on_record_committed)
        self._last_review_panel_objid = None
        self._last_review_panel_state = None
        self.cutout_widget = None
        self.review_panel = None

        if self._is_aimsz_review:
            self.review_panel = AIMSZReviewPanel(default_reviewer=self._session_reviewer)
            self.review_panel.qa_flag_changed.connect(self.on_review_panel_qa_changed)
            self.review_panel.notes_changed.connect(self.on_review_notes_changed)
            self.review_panel.reviewer_changed.connect(self.on_review_reviewer_changed)

        if self.enable_image_panel:
            buffer_dir = self._resolve_cutout_buffer_dir(spectra, cutout_buffer_dir)
            self.cutout_widget = ImageCutoutWidget(buffer_dir=buffer_dir)
            if self._is_aimsz_review and hasattr(self.cutout_widget, "qa_group"):
                self.cutout_widget.qa_group.setVisible(False)
            self.cutout_widget.qa_contamination_cb.stateChanged.connect(self.on_qa_flag_changed)
            self.cutout_widget.qa_unusable_cb.stateChanged.connect(self.on_qa_flag_changed)
            self.plot.coordinate_changed.connect(self.on_coordinate_changed)
            if (hasattr(self.plot, 'spec') and hasattr(self.plot.spec, 'ra')
                    and hasattr(self.plot.spec, 'dec')):
                objid = getattr(self.plot.spec, 'objid', None)
                self.cutout_widget.load_online_cutouts(self.plot.spec.ra, self.plot.spec.dec, objid)

        # Safe without the image panel: the sync skips a missing cutout widget
        # and still initialises the review panel.
        self.sync_qa_checkbox_from_current_spec()

        # Start background prefetching for next objects (implies image panel).
        if self.enable_background_prefetch:
            self.start_background_prefetch()

        self.make_layout()
        self.aboutToQuit.connect(self.save_dict_todf)

    def _load_history_csv(self):
        """Read ``self.output_file`` into a history dict keyed by objid.

        Returns ``(history_dict, initial_counter)``. ``initial_counter`` is the
        CSV row count, including rows whose objid could not be resolved. In
        AIMS-z review sessions, the last non-empty reviewer becomes the
        session reviewer.
        """
        print(f"Loading history from {self.output_file} ...")
        df = pd.read_csv(self.output_file)
        if 'vi_class' in df.columns:
            df.rename(columns={'vi_class': 'class_vi'}, inplace=True)
        history_dict = {}
        for _, row in df.iterrows():
            objid = self._normalize_loaded_objid(
                row,
                is_aimsz_review=self._is_aimsz_review,
                is_sparcl=self._is_sparcl,
            )
            if objid is None:
                continue
            record = self._row_to_history_record(row, is_aimsz_review=self._is_aimsz_review)
            history_dict[objid] = record
            if self._is_aimsz_review and record.get("reviewer"):
                self._session_reviewer = str(record.get("reviewer"))
        return history_dict, df.shape[0]

    def _resolve_cutout_buffer_dir(self, spectra, cutout_buffer_dir):
        """Pick the directory used to cache image cutouts.

        An explicit ``cutout_buffer_dir`` wins. Otherwise a ``cutout_buffer``
        folder is used next to the RGS file (dual mode), next to the single
        spectrum file (str or path-like), or next to the first listed file.
        """
        if cutout_buffer_dir is not None:
            return Path(cutout_buffer_dir)
        if self._dual_mode:
            return Path(self.rgs_file).parent / "cutout_buffer"
        if isinstance(spectra, (str, os.PathLike)):
            # Single FITS file; a Path must not be indexed like a list.
            return Path(spectra).parent / "cutout_buffer"
        # List of files
        return Path(spectra[0]).parent / "cutout_buffer"
def on_coordinate_changed(self, ra, dec):
    """Send new sky coordinates, plus the displayed objid, to the cutout panel.

    No-op when there is no cutout widget. Afterwards the QA controls are
    resynchronised with the current spectrum.
    """
    widget = self.cutout_widget
    if widget is None:
        return
    objid = None
    if hasattr(self.plot, 'spec'):
        objid = getattr(self.plot.spec, 'objid', None)
    widget.load_online_cutouts(ra, dec, objid)
    self.sync_qa_checkbox_from_current_spec()
def on_current_spec_changed(self):
    """Refresh side widgets after the displayed spectrum changes.

    In AIMS-z review mode, a spectrum with no reviewer first inherits the
    session reviewer. Then the QA controls, the go-to index box and the
    downsample toggle are refreshed, in that order.
    """
    if self._is_aimsz_review and hasattr(self.plot, "spec"):
        current = self.plot.spec
        if not getattr(current, "reviewer", ""):
            current.reviewer = self._session_reviewer
    for refresh in (self.sync_qa_checkbox_from_current_spec,
                    self._update_go_to_controls,
                    self._update_downsample_toggle):
        refresh()
def on_record_committed(self, objid):
    """Timestamp a freshly committed AIMS-z review record.

    Ignored outside AIMS-z review mode and for objids without a history
    record. If the committed object is the one on screen, its spectrum gets
    the same ``reviewed_at`` value. The review panel is then force-synced.
    """
    if not self._is_aimsz_review:
        return
    record = self.plot._history_record(objid=objid, create=False)
    if record is None:
        return
    stamp = self._now_reviewed_at()
    record["reviewed_at"] = stamp
    shown = getattr(self.plot, "spec", None)
    if shown is not None and getattr(shown, "objid", None) == objid:
        shown.reviewed_at = stamp
    self._sync_review_panel(force=True)
def _commit_current_review_state(self, update_timestamp=False): if not self._is_aimsz_review or not hasattr(self.plot, "spec"): return spec = self.plot.spec record = self.plot._history_record(spec=spec, create=True) record["z_vi"] = getattr(spec, "z_vi", record.get("z_vi", 0.0)) record["qa_flag"] = int(getattr(spec, "qa_flag", record.get("qa_flag", 0)) or 0) record["notes"] = str(getattr(spec, "notes", record.get("notes", "")) or "") record["reviewer"] = str( getattr(spec, "reviewer", record.get("reviewer", "")) or self._session_reviewer ) self._session_reviewer = record["reviewer"] self.plot._session_reviewer = self._session_reviewer if self.review_panel is not None: self.review_panel.set_default_reviewer(self._session_reviewer) spec.reviewer = record["reviewer"] if not record.get("class_vi"): record["class_vi"] = self.plot._default_class_token() if update_timestamp: record["reviewed_at"] = self._now_reviewed_at() spec.reviewed_at = record["reviewed_at"] self.plot.history[spec.objid] = record self._sync_review_panel(force=True) def _review_panel_state(self, spec, record): if spec is None: return None record = record or {} return ( getattr(spec, "objid", None), int(record.get("qa_flag", getattr(spec, "qa_flag", 0)) or 0), str(record.get("notes", getattr(spec, "notes", "")) or ""), str(record.get("reviewer", getattr(spec, "reviewer", "")) or ""), str(record.get("reviewed_at", getattr(spec, "reviewed_at", "")) or ""), getattr(spec, "review_priority_tier", None), getattr(spec, "review_score", None), getattr(spec, "review_rank_within_tier", None), getattr(spec, "review_slice_label", None), getattr(spec, "z_ref", None), getattr(spec, "z_ml_expect", None), getattr(spec, "z_pcf_best", None), getattr(spec, "pcf_template_best", None), getattr(spec, "pcf_score_best", None), ) def _sync_review_panel(self, force=False): if self.review_panel is None or not hasattr(self.plot, "spec"): return spec = self.plot.spec record = self.plot._history_record(spec=spec, objid=getattr(spec, 
"objid", None), create=False) state = self._review_panel_state(spec, record) if not force and state == self._last_review_panel_state: return self.review_panel.set_review_context(spec, record) self._last_review_panel_objid = getattr(spec, "objid", None) self._last_review_panel_state = state
def sync_qa_checkbox_from_current_spec(self):
    """Restore the QA controls from the current spectrum's history/attributes.

    The history record's ``qa_flag`` wins over the spectrum attribute. The
    normalised value is written back onto the spectrum, pushed to the cutout
    widget (if any), and the review panel gets a non-forced sync.
    """
    if not hasattr(self.plot, 'spec'):
        return
    spec = self.plot.spec
    record = self.plot._history_record(
        spec=spec, objid=getattr(spec, 'objid', None), create=False)
    spec_flag = getattr(spec, 'qa_flag', 0)
    raw_flag = record.get("qa_flag", spec_flag) if record else spec_flag
    qa_flag = self._normalize_qa_flag(raw_flag)
    spec.qa_flag = qa_flag
    widget = self.cutout_widget
    if widget is not None:
        widget.set_qa_flag(qa_flag)
    self._sync_review_panel(force=False)
def on_qa_flag_changed(self, _state):
    """Store the QA flag from the cutout checkboxes on the spectrum and in history.

    ``_state`` (the checkbox state) is unused; the flag is re-read from the
    widget. History is only written when the spectrum has an objid, and a
    missing class_vi is filled with the default class token.
    """
    widget = self.cutout_widget
    if widget is None or not hasattr(self.plot, 'spec'):
        return
    spec = self.plot.spec
    flag = widget.get_qa_flag()
    spec.qa_flag = flag
    objid = getattr(spec, 'objid', None)
    if objid is None:
        return
    plot = self.plot
    record = plot._history_record(spec=spec, objid=objid, create=True)
    if not record.get("class_vi"):
        record["class_vi"] = plot._default_class_token()
    record["qa_flag"] = flag
    plot.history[objid] = record
    self._sync_review_panel(force=True)
def on_review_panel_qa_changed(self, qa_flag):
    """Store a QA flag chosen in the review panel and mirror it to the cutouts.

    The spectrum attribute is always updated. History (and the cutout
    widget) only change when the spectrum has an objid.
    """
    if not hasattr(self.plot, "spec"):
        return
    plot = self.plot
    spec = plot.spec
    spec.qa_flag = int(qa_flag)
    objid = getattr(spec, "objid", None)
    if objid is None:
        return
    record = plot._history_record(spec=spec, objid=objid, create=True)
    if not record.get("class_vi"):
        record["class_vi"] = plot._default_class_token()
    record["qa_flag"] = int(qa_flag)
    plot.history[objid] = record
    widget = self.cutout_widget
    if widget is not None:
        widget.set_qa_flag(int(qa_flag))
def on_review_notes_changed(self, text):
    """Store edited review notes on the displayed spectrum and in history.

    The spectrum's ``notes`` attribute is always updated. Like the QA
    handlers, nothing is written to history when the spectrum has no objid,
    since it would otherwise be filed under the key ``None``. A missing
    class_vi is filled with the default class token.
    """
    if not hasattr(self.plot, "spec"):
        return
    spec = self.plot.spec
    spec.notes = text
    objid = getattr(spec, "objid", None)
    if objid is None:
        return
    record = self.plot._history_record(spec=spec, create=True)
    record["notes"] = text
    if not record.get("class_vi"):
        record["class_vi"] = self.plot._default_class_token()
    self.plot.history[objid] = record
def on_review_reviewer_changed(self, text):
    """Apply a reviewer name typed into the review panel.

    Blank input falls back to the current session reviewer. The name becomes
    the new session reviewer (also on the plot and as the panel default) and
    is set on the spectrum. It is written to history only when the spectrum
    has an objid, matching the QA handlers. Ends with a non-forced panel sync.
    """
    if not hasattr(self.plot, "spec"):
        return
    spec = self.plot.spec
    reviewer = str(text or "").strip() or self._session_reviewer
    spec.reviewer = reviewer
    self._session_reviewer = reviewer
    self.plot._session_reviewer = self._session_reviewer
    if self.review_panel is not None:
        self.review_panel.set_default_reviewer(self._session_reviewer)
    objid = getattr(spec, "objid", None)
    if objid is not None:
        record = self.plot._history_record(spec=spec, create=True)
        record["reviewer"] = reviewer
        if not record.get("class_vi"):
            record["class_vi"] = self.plot._default_class_token()
        self.plot.history[objid] = record
    self._sync_review_panel(force=False)
def start_background_prefetch(self):
    """Prefetch image cutouts for all remaining spectra on a daemon thread.

    The worker loads every spectrum from ``self.plot.counter`` to the end of
    the list (dual RGS/BGS files, individual files, or extensions of one
    multi-extension FITS). It keeps those that have ``ra``, ``dec`` and
    ``objid`` and hands them to the cutout widget. Load errors are printed
    and skipped.
    """
    def _load_for_prefetch(i):
        # Returns the spectrum at zero-based index ``i`` (None if dual mode
        # produced neither an RGS nor a BGS spectrum).
        if self._dual_mode:
            dual = SpecEuclid1dDual(
                rgs_file=self.rgs_file,
                bgs_file=self.bgs_file,
                ext=self.dual_ext if self.dual_ext is not None else i + 1,
                extname=self.dual_extname,
                good_pixels_only=self.dual_good_pixels_only,
            )
            return dual.rgs if dual.rgs is not None else dual.bgs
        if self.plot.speclist is not None:
            # Load from individual files
            return self.SpecClass(self.plot.speclist[i])
        # Load from multi-extension FITS
        return self.SpecClass(self.spectra, ext=i + 1)

    def prefetch_worker():
        try:
            first = getattr(self.plot, 'counter', 0)
            candidates = []
            for i in range(first, self.len_list):
                try:
                    spec = _load_for_prefetch(i)
                    if spec is None:
                        continue
                    self.plot._ensure_spec_defaults(spec)
                    # Only prefetch if has coordinates
                    if all(hasattr(spec, attr) for attr in ('ra', 'dec', 'objid')):
                        candidates.append(spec)
                except Exception as e:
                    print(f"Error loading spectrum {i} for prefetch: {e}")
            if candidates:
                self.cutout_widget.prefetch_cutouts_background(candidates, 0, None)
        except Exception as e:
            print(f"Error in background prefetch setup: {e}")

    worker = threading.Thread(target=prefetch_worker, daemon=True)
    worker.start()
def _update_downsample_toggle(self): if not hasattr(self, "downsample_toggle"): return enabled = self.plot._is_current_sparcl_like() self.downsample_toggle.blockSignals(True) self.downsample_toggle.setChecked(bool(self.plot.downsampling_enabled)) self.downsample_toggle.setEnabled(enabled) self.downsample_toggle.blockSignals(False)
def on_downsample_toggled(self, checked):
    """Store the downsampling choice and redraw SPARCL-like spectra with it.

    The current view lock is preserved on redraw, and keyboard focus returns
    to the plot.
    """
    plot = self.plot
    plot.downsampling_enabled = bool(checked)
    if hasattr(plot, "spec") and plot._is_current_sparcl_like():
        plot.clear()
        plot.plot_single(preserve_view=plot._view_lock_active)
        plot.setFocus()
def make_layout(self):
    """Create the enhanced layout with image cutouts and controls.

    While ``self.plot.counter < self.len_list + 1`` this builds:

    * a toolbar with template radio buttons, the optional "Hide Images"
      toggle, a downsample checkbox, a go-to-index box and save buttons;
    * a keyboard-shortcut help label;
    * a horizontal splitter with the plot, info label and redshift slider on
      the left, and the review panel and/or cutout widget on the right.

    The window is always stored in ``self.layout`` and shown, because
    dialogs such as ``save_data`` use it as their parent.
    """
    layout = pg.LayoutWidget()
    layout.resize(1300, 900)  # Reasonable width with text wrapping
    layout.setWindowTitle(f"PGSpecPlot Enhanced - Spectra Viewer (v{viewer_version})")

    if self.plot.counter < self.len_list + 1:
        # Create toolbar with template controls and save buttons
        toolbar = QWidget()
        toolbar_layout = QHBoxLayout()

        # Template selection
        template_group = QGroupBox("Template")
        template_group.setMaximumHeight(45)  # Make template group more compact
        template_layout = QHBoxLayout()
        template_layout.setContentsMargins(5, 2, 5, 2)  # Reduce margins
        self.template_buttons = QButtonGroup()
        for template_name in self.plot.template_manager.get_available_templates():
            btn = QRadioButton(template_name)
            if template_name == "Type 1":
                btn.setChecked(True)
            # Bind the name as a default so each button keeps its own template.
            btn.clicked.connect(
                lambda checked, name=template_name:
                    self.plot.change_template(name) if checked else None)
            self.template_buttons.addButton(btn)
            template_layout.addWidget(btn)
        template_group.setLayout(template_layout)
        toolbar_layout.addWidget(template_group)

        if self.enable_image_panel and self.cutout_widget is not None:
            self.image_toggle_btn = QPushButton("Hide Images")
            self.image_toggle_btn.clicked.connect(self.toggle_image_panel)
            self.image_toggle_btn.setCheckable(True)
            self.image_toggle_btn.setMaximumHeight(35)  # Make button more compact
            toolbar_layout.addWidget(self.image_toggle_btn)

        self.downsample_toggle = QCheckBox("Downsample (n=3)")
        self.downsample_toggle.setMaximumHeight(35)
        self.downsample_toggle.setToolTip(
            "Use pyqtgraph native downsampling for SPARCL/AIMS-z spectra "
            "(mean + auto + clip-to-view)."
        )
        self.downsample_toggle.stateChanged.connect(self.on_downsample_toggled)
        toolbar_layout.addWidget(self.downsample_toggle)

        toolbar_layout.addWidget(QLabel("Go to index:"))
        self.goto_index_spin = QSpinBox()
        self.goto_index_spin.setMinimum(1)
        # Keep max >= min so an empty list does not invert the range.
        self.goto_index_spin.setMaximum(max(1, self.len_list))
        self.goto_index_spin.setValue(1)
        self.goto_index_spin.setMaximumHeight(35)
        # Avoid stealing keyboard navigation shortcuts on startup.
        self.goto_index_spin.setFocusPolicy(Qt.ClickFocus)
        toolbar_layout.addWidget(self.goto_index_spin)
        self.goto_index_btn = QPushButton("Go")
        self.goto_index_btn.setMaximumHeight(35)
        self.goto_index_btn.clicked.connect(self.go_to_index)
        toolbar_layout.addWidget(self.goto_index_btn)

        # Add spacer
        toolbar_layout.addStretch()

        # Save buttons
        self.save_png_btn = QPushButton("Save PNG")
        self.save_png_btn.clicked.connect(self.save_png)
        self.save_png_btn.setMaximumHeight(35)
        toolbar_layout.addWidget(self.save_png_btn)

        self.save_btn = QPushButton("Save")
        self.save_btn.clicked.connect(self.save_data)
        self.save_btn.setMaximumHeight(35)  # Make button more compact
        toolbar_layout.addWidget(self.save_btn)

        self.save_quit_btn = QPushButton("Save & Quit")
        self.save_quit_btn.clicked.connect(self.save_and_quit)
        self.save_quit_btn.setMaximumHeight(35)  # Make button more compact
        toolbar_layout.addWidget(self.save_quit_btn)

        toolbar.setLayout(toolbar_layout)
        toolbar.setMaximumHeight(50)  # Limit toolbar height
        toolbar.setMinimumHeight(40)  # Set minimum height
        layout.addWidget(toolbar, row=0, col=0, colspan=2)

        # Instructions with comprehensive keyboard shortcuts
        instruction_text = (
            "Navigation: 'Q' next spectrum, Left/Right arrows previous/next | "
            "Classification: 'A' QSO, 'N' QSO(Narrow), 'B' QSO(BAL), 'F' QSO(FeLoBAL), 'D' BAD spectrum, | "
            "'S' STAR, 'G' GALAXY, 'U' UNKNOWN, 'L' LIKELY_Q | "
            "Tools: 'Space' wavelength info, 'M' mouse position, 'R' reset zoom | "
            "Advanced: Ctrl+R reload, Ctrl+Left first, Ctrl+Right last, Ctrl+B resume from history"
        )
        toplabel = layout.addLabel(instruction_text, row=1, col=0, colspan=2)
        toplabel.setFont(QFont("Arial", 15))
        toplabel.setMinimumHeight(72)
        toplabel.setMaximumHeight(96)
        toplabel.setAlignment(Qt.AlignLeft | Qt.AlignTop)
        toplabel.setStyleSheet("background-color: white;color: black;")
        toplabel.setFrameStyle(QFrame.Panel | QFrame.Raised)
        toplabel.setWordWrap(True)

        # Main content area with splitter
        main_splitter = QSplitter(Qt.Horizontal)

        # Left side: spectrum plot and controls
        left_widget = QWidget()
        left_layout = QVBoxLayout()
        # Add spectrum info label above plot
        left_layout.addWidget(self.plot.spectrum_info_label)
        left_layout.addWidget(self.plot)

        # Redshift slider
        slider_container = QWidget()
        slider_layout = QHBoxLayout()
        slider_layout.addWidget(QLabel("Redshift:"))
        slider_layout.addWidget(self.plot.slider)
        slider_layout.addWidget(self.plot.redshiftSpin)
        slider_container.setLayout(slider_layout)
        left_layout.addWidget(slider_container)

        left_widget.setLayout(left_layout)
        main_splitter.addWidget(left_widget)

        show_cutouts = self.enable_image_panel and self.cutout_widget is not None
        if self.review_panel is not None or show_cutouts:
            right_widget = QWidget()
            right_layout = QVBoxLayout()
            if self.review_panel is not None:
                right_layout.addWidget(self.review_panel)
            if show_cutouts:
                right_layout.addWidget(self.cutout_widget)
            right_layout.addStretch()
            right_widget.setLayout(right_layout)
            main_splitter.addWidget(right_widget)
            main_splitter.setSizes(self._default_splitter_sizes())
            main_splitter.setStretchFactor(0, 5 if self._is_aimsz_review else 4)
            main_splitter.setStretchFactor(1, 1)
            main_splitter.setCollapsible(1, True)
        else:
            main_splitter.setSizes([1000])

        layout.addWidget(main_splitter, row=2, col=0, colspan=2)
        self.main_splitter = main_splitter  # Store reference for toggle
        self._update_go_to_controls()

        # Keep keyboard shortcuts active by default.
        self.plot.setFocusPolicy(Qt.StrongFocus)
        self.plot.setFocus()

    self.layout = layout
    self.layout.show()
    self._update_downsample_toggle()
def _default_splitter_sizes(self): if self._is_aimsz_review: if self.enable_image_panel and self.cutout_widget is not None: return [900, 280] return [980, 220] if self.review_panel is not None and self.cutout_widget is not None: return [780, 320] if self.review_panel is not None: return [860, 280] if self.cutout_widget is not None: return [800, 200] return [1000]
def keyPressEvent(self, event):
    """Forward a key press to the plot, then resync the side widgets.

    After the plot handles the key, the QA controls and go-to index box are
    refreshed and keyboard focus returns to the plot.
    """
    plot_widget = self.plot
    plot_widget.keyPressEvent(event)
    self.sync_qa_checkbox_from_current_spec()
    self._update_go_to_controls()
    plot_widget.setFocus()
def _current_index_one_based(self):
    """Return the 1-based index of the spectrum currently shown.

    If the plot exposes ``_displaying_spectrum_number``, that value is used
    as-is (converted to ``int``). Otherwise ``plot.counter`` (default 1) is
    used, clamped into ``[1, self.len_list]``.
    """
    plot = self.plot
    if hasattr(plot, "_displaying_spectrum_number"):
        return int(plot._displaying_spectrum_number)
    position = int(getattr(plot, "counter", 1))
    # The lower bound is checked first, so an empty list still yields 1 for counter < 1.
    if position < 1:
        return 1
    return min(position, self.len_list)


def _update_go_to_controls(self):
    """Sync the go-to spin box with the currently displayed spectrum index.

    No-op when the spin box has not been created yet. Signals are blocked
    while the value is set, so this programmatic update does not emit
    ``valueChanged`` to any connected handlers.
    """
    if not hasattr(self, "goto_index_spin"):
        return
    index_now = self._current_index_one_based()
    spin = self.goto_index_spin
    spin.blockSignals(True)
    spin.setValue(index_now)
    spin.blockSignals(False)
def go_to_index(self):
    """Jump to the spectrum whose 1-based index is in the go-to spin box.

    Does nothing if the spin box has not been created. After the jump, the
    QA checkbox and spin box are resynced and focus returns to the plot.
    """
    if hasattr(self, "goto_index_spin"):
        requested = int(self.goto_index_spin.value())
        self.plot.jump_to_spectrum(requested)
        self.sync_qa_checkbox_from_current_spec()
        self._update_go_to_controls()
        self.plot.setFocus()
def mousePressEvent(self, event):
    """Delegate a mouse press straight to the spectrum plot widget."""
    forward_to_plot = self.plot.mousePressEvent
    forward_to_plot(event)
def save_data(self):
    """Save current classification data to CSV and tell the user.

    ``save_dict_todf`` returns without writing anything when
    ``self.plot.history`` is empty. In that case the user is told that
    nothing was saved, instead of getting a false "Data saved to ..." message.
    """
    if not self.plot.history:
        QMessageBox.information(
            self.layout,
            "Nothing to save",
            "No classifications have been recorded yet; no file was written.",
        )
        return
    self.save_dict_todf()
    QMessageBox.information(self.layout, "Saved", f"Data saved to {self.output_file}")
def save_png(self):
    """Save the entire application window as a PNG image.

    The screenshot (``self.layout.grab()``) is written to ``./saved_pngs``,
    which is created if needed. The file name comes from the current spectrum:

    * Euclid spectra (``telescope == "euclid"``): ``euclid_<objid>_vi.png``
    * ``SpecSparcl`` / ``SpecAIMSZReview`` spectra:
      ``<survey>_<targetid>_vi.png``, or ``<survey>_<objid>_vi.png`` when
      there is no usable targetid
    * anything else: ``<objid>_vi.png``

    An existing file is never overwritten; ``_2``, ``_3``, ... is appended to
    the stem instead. Success or failure is reported in a message box.
    """

    def _sanitize_component(value):
        # Reduce an arbitrary value to a safe single filename component.
        s = str(value).strip()
        s = s.replace(os.sep, "_").replace(" ", "_")
        s = re.sub(r"[^0-9A-Za-z_.-]+", "_", s)
        return s.strip("_") or "unknown"

    def _infer_survey(spec):
        # Derive a short survey prefix from the spectrum's data-release label.
        dr = str(getattr(spec, "data_release", "") or getattr(spec, "_dr", "") or "")
        dr_l = dr.lower()
        if "desi" in dr_l:
            return "desi"
        if "sdss" in dr_l or "boss" in dr_l or "eboss" in dr_l:
            return "sdss"
        if "lamost" in dr_l:
            return "lamost"
        if "gaia" in dr_l:
            return "gaia"
        if dr_l:
            # keep it short and filesystem-friendly
            return _sanitize_component(dr_l)[:24]
        return "sparcl"

    spec = self.plot.spec if hasattr(self.plot, "spec") else None
    objid = getattr(spec, "objid", "spectrum") if spec is not None else "spectrum"
    objid_str = _sanitize_component(objid)
    # ``telescope`` may be missing or None; ``or ""`` avoids calling ``.lower()`` on None.
    telescope = str(getattr(spec, "telescope", "") or "").lower() if spec is not None else ""
    # The spectrum classes may not be importable in every install, hence the globals() guard.
    is_survey_spec = (
        ("SpecSparcl" in globals() and isinstance(spec, SpecSparcl))
        or ("SpecAIMSZReview" in globals() and isinstance(spec, SpecAIMSZReview))
    )

    if telescope == "euclid":
        base_name = f"euclid_{objid_str}_vi.png"
    elif is_survey_spec:
        survey = _infer_survey(spec)
        targetid = getattr(spec, "targetid", None)
        if targetid not in (None, "", 0):
            base_name = f"{survey}_{_sanitize_component(targetid)}_vi.png"
        else:
            base_name = f"{survey}_{objid_str}_vi.png"
    else:
        base_name = f"{objid_str}_vi.png"

    out_dir = Path.cwd() / "saved_pngs"
    try:
        out_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        QMessageBox.warning(self.layout, "Save PNG failed", f"Could not create {out_dir}: {e}")
        return

    filename = out_dir / base_name
    if filename.exists():
        # Never overwrite: find the first free "<stem>_<n>.png", starting at 2.
        stem = filename.stem
        n = 2
        while (out_dir / f"{stem}_{n}.png").exists():
            n += 1
        filename = out_dir / f"{stem}_{n}.png"

    try:
        # Ensure latest visuals are painted before grabbing.
        self.layout.repaint()
        QApplication.processEvents()
        ok = self.layout.grab().save(str(filename), "PNG")
    except Exception as e:
        QMessageBox.warning(self.layout, "Save PNG failed", str(e))
        return
    if not ok:
        QMessageBox.warning(self.layout, "Save PNG failed", "Qt failed to write the PNG file.")
        return
    QMessageBox.information(self.layout, "Saved", f"Saved PNG to {filename}")
def save_and_quit(self):
    """Save classification data (if any), report the outcome, and quit.

    When ``self.plot.history`` is empty, ``save_dict_todf`` writes nothing.
    In that case the user is told nothing was saved, instead of seeing a
    false "Data saved to ..." message. The application quits either way.
    """
    if self.plot.history:
        self.save_dict_todf()
        QMessageBox.information(self.layout, "Saved", f"Data saved to {self.output_file}")
    else:
        QMessageBox.information(
            self.layout,
            "Nothing to save",
            "No classifications have been recorded yet; no file was written.",
        )
    self.quit()
def toggle_image_panel(self):
    """Toggle visibility of the image cutout panel.

    The action follows the checked state of ``image_toggle_btn``:

    * checked: hide the cutouts and turn cutout auto-fetch off
    * unchecked: show them again, restore the default splitter sizes and
      turn auto-fetch back on

    Does nothing when the image panel is disabled or there is no cutout
    widget.

    The cutout widget shares the right-hand splitter pane with the review
    panel. The pane is therefore collapsed completely only when there is no
    review panel. Otherwise it shrinks to the review-only sizes, so the QA
    controls stay visible while images are hidden.
    """
    if not self.enable_image_panel or self.cutout_widget is None:
        return
    has_splitter = hasattr(self, "main_splitter")
    if self.image_toggle_btn.isChecked():
        if has_splitter:
            if getattr(self, "review_panel", None) is not None:
                # Same sizes the initial layout uses for "review panel, no images".
                aimsz = getattr(self, "_is_aimsz_review", False)
                self.main_splitter.setSizes([980, 220] if aimsz else [860, 280])
            else:
                self.main_splitter.setSizes([1000, 0])
        self.cutout_widget.setVisible(False)
        self.image_toggle_btn.setText("Show Images")
        self.cutout_widget.auto_fetch_cb.setChecked(False)
    else:
        self.cutout_widget.setVisible(True)
        if has_splitter:
            self.main_splitter.setSizes(self._default_splitter_sizes())
        self.image_toggle_btn.setText("Hide Images")
        self.cutout_widget.auto_fetch_cb.setChecked(True)
def run_cross_correlation(self):
    """Show a placeholder dialog for the not-yet-implemented cross-correlation feature."""
    # QMessageBox needs a QWidget parent. Other dialogs here use self.layout;
    # ``self`` itself does not appear to be a widget (it calls self.quit() and
    # shows self.layout as its window).
    QMessageBox.information(
        self.layout,
        "Cross-Correlation",
        "Cross-correlation feature is coming soon!\n\n"
        "This will perform automatic redshift measurement\n"
        "using template cross-correlation.",
    )
def save_dict_todf(self):
    """Write the classification history to ``self.output_file`` as CSV.

    Returns immediately if ``self.plot.history`` is empty. Otherwise:

    1. Commit the current review state (``update_timestamp`` is True only in
       AIMSZ review mode).
    2. Build a DataFrame from ``plot._history_rows_for_csv()``.
    3. In AIMSZ review mode, force the columns into a fixed order: missing
       columns become empty and any other columns are dropped.
    4. Write the CSV without an index.
    """
    if not self.plot.history:
        return
    self._commit_current_review_state(update_timestamp=self.plot._is_aimsz_review)
    table = pd.DataFrame(self.plot._history_rows_for_csv())
    if self.plot._is_aimsz_review:
        aimsz_columns = [
            "objid",
            "targetid",
            "ra",
            "dec",
            "data_release",
            "class_vi",
            "z_vi",
            "qa_flag",
            "notes",
            "reviewer",
            "reviewed_at",
        ]
        table = table.reindex(columns=aimsz_columns)
    table.to_csv(self.output_file, index=False)
    print(f"Results saved to {self.output_file}")
class PGSpecPlotThreadEnhanced(QThread):
    """Enhanced thread wrapper for the enhanced application.

    Builds a ``PGSpecPlotAppEnhanced`` for the given spectra.

    When the image panel is enabled and no cutout buffer directory exists
    yet, it first offers, on an interactive terminal, to pre-download all
    cutouts. If that bulk download succeeds, the Qt window is skipped and the
    user is asked to relaunch against the filled buffer.

    Parameters
    ----------
    spectra : str, path-like, or sequence of paths, optional
        Spectrum input passed to the application.
    SpecClass : type, optional
        Spectrum reader class (default ``SpecEuclid1d``).
    specfile : optional
        Backward-compatible alias for ``spectra``.
    **kwargs
        Forwarded to ``PGSpecPlotAppEnhanced``, except:

        * ``cutout_buffer_dir`` is consumed here.
        * ``enable_image_panel`` defaults to ``False``.
        * ``rgs_file`` / ``bgs_file``, with optional ``ext``, ``extname`` and
          ``dual_good_pixels_only``, select dual-spectrum mode.

    Raises
    ------
    ValueError
        If no spectrum input is given and dual mode is not configured.
    """

    def __init__(self, spectra=None, SpecClass=SpecEuclid1d, specfile=None, **kwargs):
        super().__init__()
        explicit_buffer_dir = kwargs.pop("cutout_buffer_dir", None)
        # The image panel is opt-in: force it off unless the caller set it.
        if "enable_image_panel" not in kwargs:
            kwargs["enable_image_panel"] = False
        self.enable_image_panel = bool(kwargs.get("enable_image_panel", True))
        self.rgs_file = kwargs.get("rgs_file", None)
        self.bgs_file = kwargs.get("bgs_file", None)
        self.dual_ext = kwargs.get("ext", None)
        self.dual_extname = kwargs.get("extname", None)
        self.dual_good_pixels_only = bool(kwargs.get("dual_good_pixels_only", False))
        self._dual_mode = bool(self.rgs_file and self.bgs_file)
        self._dual_parquet_mode = (
            self._dual_mode
            and _is_dataframe_backed_spectrum_path(self.rgs_file)
            and _is_dataframe_backed_spectrum_path(self.bgs_file)
        )
        self._dual_pair_keys = []

        # Backward compatibility: ``specfile`` is accepted when ``spectra`` is not given.
        if spectra is None and specfile is not None:
            self.spectra = specfile
        elif spectra is not None:
            self.spectra = spectra
        elif self._dual_mode:
            self.spectra = self.rgs_file
        else:
            raise ValueError("Either 'spectra' or 'specfile' must be provided")

        self.SpecClass = SpecClass
        self.app = None
        self._skip_window = False
        self._disable_background_prefetch = not self.enable_image_panel
        self.buffer_dir = None

        if self._dual_parquet_mode:
            self._dual_pair_keys, _ = _ordered_shared_dual_pair_keys(self.rgs_file, self.bgs_file)

        if self.enable_image_panel:
            self.buffer_dir = (
                Path(explicit_buffer_dir) if explicit_buffer_dir else self._resolve_buffer_dir(self.spectra)
            )
            if self._should_offer_predownload(self.buffer_dir):
                if self._prompt_for_bulk_download():
                    self._skip_window = self._run_bulk_predownload()

        if not self._skip_window:
            if "enable_background_prefetch" not in kwargs:
                kwargs["enable_background_prefetch"] = not self._disable_background_prefetch
            self.app = PGSpecPlotAppEnhanced(
                self.spectra,
                self.SpecClass,
                cutout_buffer_dir=self.buffer_dir,
                **kwargs,
            )

    @staticmethod
    def _resolve_buffer_dir(spectra):
        """Resolve the cutout buffer path from the input spectra.

        For a single path (``str`` or any ``os.PathLike``, e.g. ``pathlib.Path``)
        this is ``<parent>/cutout_buffer``. For a sequence of paths, the
        first entry's parent directory is used.
        """
        if isinstance(spectra, (str, os.PathLike)):
            return Path(spectra).parent / "cutout_buffer"
        return Path(spectra[0]).parent / "cutout_buffer"

    @staticmethod
    def _count_hdus_in_file(filename):
        """Return the number of spectra in ``filename`` (dual-mode RGS file).

        Uses ``SpecEuclid1d.count_in_file`` when available. If that is missing
        or fails, falls back to counting FITS extensions beyond the primary HDU.
        """
        try:
            if hasattr(SpecEuclid1d, "count_in_file"):
                return int(SpecEuclid1d.count_in_file(filename))
        except Exception:
            # Best effort: fall through to the plain FITS extension count.
            pass
        with fits.open(filename) as hdul:
            return max(0, len(hdul) - 1)

    @staticmethod
    def _should_offer_predownload(buffer_dir):
        """Offer predownload only when the cutout buffer does not exist yet."""
        return buffer_dir is not None and not Path(buffer_dir).exists()

    @staticmethod
    def _prompt_for_bulk_download():
        """Ask on the CLI whether to pre-download all cutouts.

        Returns ``False`` without asking when stdin is not an interactive
        terminal, and also if stdin is closed while waiting for an answer.
        """
        if not sys.stdin or not sys.stdin.isatty():
            print("No interactive terminal detected; continuing with on-the-fly cutout download.")
            return False
        while True:
            try:
                answer = input(
                    "No 'cutout_buffer' folder found. Download all cutouts before launching the Qt window? [y/N]: "
                ).strip().lower()
            except EOFError:
                # stdin closed mid-prompt (e.g. Ctrl-D): use the default answer "no".
                return False
            if answer in ("y", "yes"):
                return True
            if answer in ("", "n", "no"):
                return False
            print("Please answer 'y' or 'n'.")

    @staticmethod
    def _cutout_record(spec):
        """Return the objid/ra/dec record used for cutout download (missing attributes become None)."""
        return {
            "objid": getattr(spec, "objid", None),
            "ra": getattr(spec, "ra", None),
            "dec": getattr(spec, "dec", None),
        }

    def _collect_cutout_records(self):
        """Collect objid/ra/dec records for all spectra.

        Spectra that fail to load are reported on stdout and skipped. If the
        spectrum count cannot be determined, the records collected so far
        (an empty list) are returned.
        """
        records = []
        if self._dual_mode:
            if self._dual_parquet_mode:
                if self.dual_extname not in (None, ""):
                    selectors = [(None, str(self.dual_extname))]
                elif self.dual_ext is not None:
                    # ``ext`` is 1-based over the shared RGS/BGS pair keys.
                    pair_index = int(self.dual_ext) - 1
                    if pair_index < 0 or pair_index >= len(self._dual_pair_keys):
                        return records
                    selectors = [(None, self._dual_pair_keys[pair_index])]
                else:
                    selectors = [(None, key) for key in self._dual_pair_keys]
            elif self.dual_ext is not None or self.dual_extname not in (None, ""):
                selectors = [(self.dual_ext, self.dual_extname)]
            else:
                try:
                    total = self._count_hdus_in_file(self.rgs_file)
                except Exception as exc:
                    print(f"Failed to determine dual spectrum count for predownload: {exc}")
                    return records
                selectors = [(ext, self.dual_extname) for ext in range(1, total + 1)]

            for ext, extname in selectors:
                try:
                    dual = SpecEuclid1dDual(
                        rgs_file=self.rgs_file,
                        bgs_file=self.bgs_file,
                        ext=ext,
                        extname=extname,
                        good_pixels_only=self.dual_good_pixels_only,
                    )
                    # Prefer the RGS spectrum's coordinates; fall back to BGS.
                    spec = dual.rgs if dual.rgs is not None else dual.bgs
                    if spec is None:
                        continue
                    records.append(self._cutout_record(spec))
                except Exception as exc:
                    selector_desc = extname if extname not in (None, "") else ext
                    print(f"Failed to load dual spectrum selector={selector_desc} for predownload: {exc}")
            return records

        if isinstance(self.spectra, (list, tuple, np.ndarray)):
            for filename in list(self.spectra):
                try:
                    spec = self.SpecClass(filename)
                    records.append(self._cutout_record(spec))
                except Exception as exc:
                    print(f"Failed to load spectrum '{filename}' for predownload: {exc}")
            return records

        # Single multi-extension file: one spectrum per extension, 1-based.
        try:
            if hasattr(self.SpecClass, "count_in_file"):
                total = int(self.SpecClass.count_in_file(self.spectra))
            else:
                with fits.open(self.spectra) as hdul:
                    total = len(hdul) - 1
        except Exception as exc:
            print(f"Failed to determine spectrum count for predownload: {exc}")
            return records
        for ext in range(1, total + 1):
            try:
                spec = self.SpecClass(self.spectra, ext=ext)
                records.append(self._cutout_record(spec))
            except Exception as exc:
                print(f"Failed to load spectrum ext={ext} for predownload: {exc}")
        return records

    def _run_bulk_predownload(self):
        """Run the one-time bulk predownload and ask the user to restart.

        Returns ``True`` when the Qt window should be skipped (download done,
        user told to relaunch). Returns ``False`` to continue with on-the-fly
        downloads, which happens in two cases:

        * no records were found;
        * at least half of the attempted downloads failed or had no data.
          Background prefetch is also disabled in this case.
        """
        print(f"Preparing bulk cutout download into '{self.buffer_dir}' ...")
        records = self._collect_cutout_records()
        if not records:
            print("No valid spectra records found for bulk predownload. Continuing with on-the-fly downloads.")
            return False
        summary = predownload_cutouts(
            records=records,
            buffer_dir=self.buffer_dir,
            surveys=EUCLID_CUTOUT_SURVEYS,
            size_arcsec=10,
            progress_callback=print_cli_progress,
        )
        attempted = max(summary["total"] - summary["skipped"], 0)
        failed_total = summary["failed"] + summary["no_data"]
        fail_rate = failed_total / attempted if attempted > 0 else 0.0
        print(
            "Bulk cutout download finished. "
            f"downloaded={summary['downloaded']}, skipped={summary['skipped']}, "
            f"no_data={summary['no_data']}, failed={summary['failed']}."
        )
        if fail_rate >= 0.5:
            print(
                "Bulk predownload failure rate is high; continuing now with on-the-fly downloads "
                "and disabling background prefetch for this run."
            )
            self._disable_background_prefetch = True
            return False
        print("Please run the program again to launch the Qt window with the prebuilt cutout buffer.")
        return True

    def run(self):
        """Run the application's event loop and exit with its return code.

        Returns immediately, without exiting, when the window was skipped
        after a bulk cutout predownload.
        """
        if self._skip_window:
            return
        exit_code = self.app.exec_()
        sys.exit(exit_code)