from flask import Flask
from flask_caching import Cache
from flask_socketio import SocketIO
from jupyter_dash import JupyterDash # as JupyterDash
from specdash import base_logs_directory, external_stylesheets, external_scripts, port, do_log, max_num_traces
import numpy as np
import uuid
from specdash import app_layout, callbacks
from datetime import datetime
import json
import multiprocessing
import base64
from .models.enum_models import WavelengthUnit, FluxUnit, SpectrumType
from .models.data_models import Trace, Spectrum, SpectralLine
from specdash.colors import get_next_color
import specdash.flux as fl
from specdash import input
from specdash.utilities import get_specid_list, get_unused_port
from .smoothing.smoother import Smoother, default_smoothing_kernels, SmoothingKernels
from .fitting.fitter import ModelFitter, default_fitting_models, FittingModels
from specutils import Spectrum1D, analysis, fitting, manipulation, SpectralRegion
import uuid
from dash import no_update
from collections import OrderedDict
# Multiprocessing manager shared by the module; Viewer uses it to create its dict proxies
# (application data, timestamps, debug data).
process_manager = multiprocessing.Manager()

# Inline style for <pre> elements (presumably consumed by the layout code; not referenced in this module).
styles = dict(
    pre=dict(
        border='thin lightgrey solid',
        overflowX='scroll',
    ),
)
[docs]class Viewer():
"""
Class representing the spectrum viewer object.

It owns the Flask server, the JupyterDash app and the application data (a multiprocessing
manager dict), and provides the helpers that load, transform (smooth, bin, fit, rescale)
and remove spectral traces.
"""
# Top-level keys of the application-data dictionary; build_app_data() creates an entry for each.
APP_DATA_KEYS = ["traces", "fitted_models", "selection", "smoothing_kernel_types", "fitting_model_types",
"redshift_distributions", "metadata", "line_analysis", 'axis_units', "updates",
'trace_store_mapping', 'zdist_store_mapping']
def __init__(self, as_website=False):
    """
    Instantiate the viewer, configured to run either as a stand-alone website or inside Jupyter.

    :param as_website: boolean
        Set to True if the viewer is intended to run as a stand-alone website, and to False otherwise
        (i.e., when it runs inside Jupyter).
    """
    self.as_website = as_website
    self.app_port = get_unused_port(initial_port=port)
    if not self.as_website:
        JupyterDash.infer_jupyter_proxy_config()
        assets_ignore = ""
    else:
        # In website mode no SocketIO server is created (see below), so the websocket client asset is skipped.
        assets_ignore = "websocket.js"
    self.server = Flask(__name__)  # define flask app.server
    self.app = JupyterDash(__name__, external_stylesheets=external_stylesheets, external_scripts=external_scripts,
                           server=self.server, assets_ignore=assets_ignore, suppress_callback_exceptions=True)
    # The SocketIO server pushes "update" messages to the page and only exists in Jupyter mode; the attribute
    # is defined in both modes so callers can check it instead of hitting an AttributeError.
    self.socketio = None
    if not as_website:
        self.socketio = SocketIO(self.server, async_mode="threading", logger=True, engineio_logger=True)
    # NOTE(review): hard-coded secret key; consider loading it from configuration for public deployments.
    self.app.server.secret_key = 'SOME_KEY_STRING'
    self.app.title = "SpecDash"
    self._initialize_app_data()
    self.app_data_timestamp['timestamp'] = 0
    self.debug_data = process_manager.dict()
    self.smoother = Smoother()
    self.model_fitter = ModelFitter()
    # Both modes currently keep the data in memory.
    self.storage_mode = "memory"
    # Assigned as a callable: Dash calls it to build the layout instead of using a fixed component tree.
    self.app.layout = self._get_app_layout
    callbacks.load_callbacks(self)
    self.initialize_api_endpoints()
def initialize_api_endpoints(self):
    """
    Register the viewer's HTTP API routes on the Flask server.

    Currently a single health-check route, ``/api/health``, which returns ``{'is_healthy': True}``.
    """
    def health():
        return {'is_healthy': True}

    # Same as decorating ``health`` with ``@self.server.route('/api/health')``: the endpoint name is
    # derived from the view function's name.
    self.server.add_url_rule('/api/health', view_func=health)
def _initialize_app_data(self):
    """
    (Re)create the shared application-data and timestamp dictionaries.

    Both are ``process_manager.dict()`` proxies. The application data is filled with the values of a fresh
    ``Viewer.build_app_data()`` and its "updates" record is then reset.
    """
    self.app_data = process_manager.dict()
    self.app_data_timestamp = process_manager.dict()
    initial_values = Viewer.build_app_data()
    for data_key in initial_values:
        self.app_data[data_key] = initial_values[data_key]
    # Reset the pending-updates record on the proxy itself (build_app_data() also initializes it).
    self._initialize_updates(self.app_data)
@staticmethod
def _initialize_updates(app_data):
_updates = {}
_updates["updated_traces"] = []
_updates["removed_traces"] = []
_updates["added_traces"] = []
_updates["updated_zdists"] = []
_updates["removed_zdists"] = []
_updates["added_zdists"] = []
app_data["updates"] = _updates
@staticmethod
def _set_trace_updates_info(app_data, added_trace_names=(), removed_trace_names=(), updated_trace_names=(),
                            added_zdist_names=(), removed_zdist_names=(), updated_zdist_names=()):
    """
    Record added/removed/updated trace and redshift-distribution names in ``app_data["updates"]``.

    The given names are appended to the matching lists of ``app_data["updates"]``; each list is then
    de-duplicated, keeping first-seen order. The updates dict is written back to ``app_data`` at the end,
    which is required when ``app_data`` is a manager dict proxy (its item reads return copies).

    :param app_data: application data holding 'traces', 'redshift_distributions' and 'updates'.
    :param added_trace_names: iterable of names of added traces (default: none).
    :param removed_trace_names: iterable of names of removed traces (default: none).
    :param updated_trace_names: iterable of names of updated traces (default: none).
    :param added_zdist_names: iterable of names of added redshift distributions (default: none).
    :param removed_zdist_names: iterable of names of removed redshift distributions (default: none).
    :param updated_zdist_names: iterable of names of updated redshift distributions (default: none).
    :raises Exception: if more than ``max_num_traces`` traces or redshift distributions are loaded.
    """
    if len(app_data['traces']) > max_num_traces or len(app_data['redshift_distributions']) > max_num_traces:
        raise Exception("Maximum number of loaded traces exceeded")
    new_names_by_key = (
        ("added_traces", added_trace_names),
        ("removed_traces", removed_trace_names),
        ("updated_traces", updated_trace_names),
        ("added_zdists", added_zdist_names),
        ("removed_zdists", removed_zdist_names),
        ("updated_zdists", updated_zdist_names),
    )
    _updates = app_data["updates"]
    for key, names in new_names_by_key:
        # OrderedDict.fromkeys drops duplicates while keeping the first occurrence's position.
        _updates[key] = list(OrderedDict.fromkeys(list(_updates[key]) + list(names)))
    app_data["updates"] = _updates
@staticmethod
def _build_trace_store_mapping(app_data):
if len(app_data['trace_store_mapping']) == 0:
trace_store_mapping = app_data['trace_store_mapping']
for index, trace_name in enumerate(app_data['traces']):
trace_store_mapping[trace_name] = index
app_data['trace_store_mapping'] = trace_store_mapping
@staticmethod
def _build_zdist_store_mapping(app_data):
if len(app_data['zdist_store_mapping']) == 0:
zdist_store_mapping = app_data['zdist_store_mapping']
for index, zdist_name in enumerate(app_data['redshift_distributions']):
zdist_store_mapping[zdist_name] = index
app_data['zdist_store_mapping'] = zdist_store_mapping
@staticmethod
def _get_returned_store_traces(app_data):
    """
    Compute the values to push into the client-side trace stores.

    There are ``max_num_traces`` store slots; ``app_data["trace_store_mapping"]`` maps trace names to slots.
    Based on ``app_data["updates"]``:

    * updated traces: their slot receives the current trace dict;
    * removed traces: their slot receives ``None`` and is released;
    * added traces: each takes the lowest free slot, which receives the trace dict.

    Every other slot is ``no_update``. Names without a slot (updates/removals) or no longer present in
    ``app_data["traces"]`` (updates/additions) are skipped instead of raising ``KeyError``. An added trace is
    also skipped when no slot is free.

    :param app_data: application data (plain dict or manager dict proxy).
    :return: list of length ``max_num_traces``.
    """
    returned_store_traces = [no_update] * max_num_traces
    traces = app_data["traces"]
    updates = app_data["updates"]
    # Read the mapping once and only use this local copy: a manager dict proxy returns a new copy on each
    # read, which would not reflect the slots assigned/released below.
    trace_store_mapping = app_data["trace_store_mapping"]
    for name in updates["updated_traces"]:
        if name in trace_store_mapping and name in traces:
            returned_store_traces[trace_store_mapping[name]] = traces[name]
    for name in updates["removed_traces"]:
        slot = trace_store_mapping.pop(name, None)
        if slot is not None:
            returned_store_traces[slot] = None
    for name in updates["added_traces"]:
        if name not in traces:
            continue  # added and removed again before the client was refreshed
        used_slots = set(trace_store_mapping.values())
        free_slot = next((i for i in range(max_num_traces) if i not in used_slots), None)
        if free_slot is not None:
            returned_store_traces[free_slot] = traces[name]
            trace_store_mapping[name] = free_slot
    app_data["trace_store_mapping"] = trace_store_mapping
    return returned_store_traces
@staticmethod
def _get_returned_store_zdists(app_data):
    """
    Compute the values to push into the client-side redshift-distribution stores.

    There are ``max_num_traces`` store slots; ``app_data["zdist_store_mapping"]`` maps distribution names to
    slots. Based on ``app_data["updates"]``:

    * updated distributions: their slot receives the current distribution dict;
    * removed distributions: their slot receives ``None`` and is released;
    * added distributions: each takes the lowest free slot, which receives the distribution dict.

    Every other slot is ``no_update``. Names without a slot (updates/removals) or no longer present in
    ``app_data["redshift_distributions"]`` (updates/additions) are skipped instead of raising ``KeyError``.
    An added distribution is also skipped when no slot is free.

    :param app_data: application data (plain dict or manager dict proxy).
    :return: list of length ``max_num_traces``.
    """
    returned_store_zdists = [no_update] * max_num_traces
    zdists = app_data["redshift_distributions"]
    updates = app_data["updates"]
    # Use a single local copy of the mapping throughout. Re-reading app_data["zdist_store_mapping"] from a
    # manager dict proxy returns a stale copy, which made several added distributions share one slot and
    # kept released slots looking occupied.
    zdist_store_mapping = app_data["zdist_store_mapping"]
    for name in updates["updated_zdists"]:
        if name in zdist_store_mapping and name in zdists:
            returned_store_zdists[zdist_store_mapping[name]] = zdists[name]
    for name in updates["removed_zdists"]:
        slot = zdist_store_mapping.pop(name, None)
        if slot is not None:
            returned_store_zdists[slot] = None
    for name in updates["added_zdists"]:
        if name not in zdists:
            continue  # added and removed again before the client was refreshed
        used_slots = set(zdist_store_mapping.values())
        free_slot = next((i for i in range(max_num_traces) if i not in used_slots), None)
        if free_slot is not None:
            returned_store_zdists[free_slot] = zdists[name]
            zdist_store_mapping[name] = free_slot
    app_data["zdist_store_mapping"] = zdist_store_mapping
    return returned_store_zdists
def _get_app_layout(self):
    """Build the Dash layout for this viewer by delegating to ``app_layout.load_app_layout``."""
    layout = app_layout.load_app_layout(self=self)
    return layout
def show_jupyter_app(self, debug=False, mode='jupyterlab'):
    """
    Open the spectrum viewer inside Jupyter.

    When the viewer was not created as a website, the application data is reset first. The Dash server is
    then started on ``self.app_port`` with the dev-tools UI, props check and hot reload enabled and route
    logging silenced.

    :param debug: forwarded to ``run_server`` as ``debug``.
    :param mode: JupyterDash display mode forwarded to ``run_server`` (default ``'jupyterlab'``).
    """
    if not self.as_website:
        self._initialize_app_data()
    server_options = dict(
        mode=mode,
        port=self.app_port,
        debug=debug,
        dev_tools_ui=True,
        dev_tools_props_check=True,
        dev_tools_hot_reload=True,
        dev_tools_silence_routes_logging=True,
    )
    self.app.run_server(**server_options)  # dash + jupyterdash
@staticmethod
def build_app_data():
    """
    Build a fresh application-data dictionary.

    Every key of ``Viewer.APP_DATA_KEYS`` starts as an empty ``OrderedDict``. The smoothing-kernel and
    fitting-model type lists are seeded with copies of the defaults, and the pending-updates record is
    initialized.

    :return: plain dict keyed by ``Viewer.APP_DATA_KEYS``.
    """
    app_data = {key: OrderedDict() for key in Viewer.APP_DATA_KEYS}
    # Copy the default lists: custom kernels/models are later appended to these lists in place, which would
    # otherwise modify the module-level defaults shared by every viewer and session.
    app_data["smoothing_kernel_types"] = list(default_smoothing_kernels)
    app_data['fitting_model_types'] = list(default_fitting_models)
    Viewer._initialize_updates(app_data)
    Viewer._set_trace_updates_info(app_data)
    return app_data
[docs] @staticmethod
def build_graph_settings(axis_units_changed=False):
graph_settings = {'axis_units_changed': axis_units_changed}
return graph_settings
def _parse_uploaded_file(self, contents, file_name, catalog_name, wavelength_unit=WavelengthUnit.ANGSTROM,
                         flux_unit=FluxUnit.F_lambda):
    """
    Decode an uploaded file and load its spectra as rescaled trace dicts.

    :param contents: payload of the form ``"<header>,<base64 data>"`` (presumably a data URL from a browser
        upload component; confirm against the caller).
    :param file_name: uploaded file name; its last extension, if any, is dropped to form the trace name.
    :param catalog_name: catalog passed on to the loader.
    :param wavelength_unit: target wavelength unit for the traces.
    :param flux_unit: target flux unit for the traces.
    :return: whatever ``_load_from_file`` returns (a list of trace dicts).
    """
    _content_type, encoded_payload = contents.split(',')
    raw_bytes = base64.b64decode(encoded_payload)
    trace_name = file_name
    if "." in trace_name:
        trace_name = trace_name.rsplit(".", 1)[0]
    return self._load_from_file(trace_name=trace_name, catalog_name=catalog_name, decoded_bytes=raw_bytes,
                                file_path=None, wavelength_unit=wavelength_unit, flux_unit=flux_unit)
def _load_from_specid_text(self, specid_text, wavelength_unit, flux_unit, data_dict, catalog_name,
                           do_update_client=False):
    """
    Parse spec ids from free text and load each one as traces named after its spec id.

    The client is notified once at the end (when ``do_update_client`` is True), not per spec id.
    """
    specids = get_specid_list(specid_text)
    # The spec ids double as trace names; each role gets its own copy of the list.
    self._load_from_specid(list(specids), list(specids), wavelength_unit, flux_unit,
                           data_dict, catalog_name, do_update_client=False)
    if do_update_client:
        self.update_client()
def _set_axis_units(self, data_dict, wavelength_unit, flux_unit):
if wavelength_unit is not None and flux_unit is not None:
axis_units = data_dict['axis_units']
axis_units['wavelength_unit'] = str(wavelength_unit)
axis_units['flux_unit'] = str(flux_unit)
data_dict['axis_units'] = axis_units
def _load_from_specid(self, specid_list, trace_name_list, wavelength_unit, flux_unit, data_dict, catalog_name,
                      do_update_client=False):
    """
    Load the spectra (and redshift distribution, if any) of each spec id and add them to ``data_dict``.

    :param specid_list: spec ids to load.
    :param trace_name_list: trace name for each spec id (indexed in parallel with ``specid_list``).
    :param wavelength_unit: target wavelength unit; if None, taken from the first loaded spectrum and reused
        for the following spec ids.
    :param flux_unit: target flux unit; if None, taken from the first loaded spectrum likewise.
    :param data_dict: application data (plain dict or manager dict proxy).
    :param catalog_name: catalog passed on to the loader.
    :param do_update_client: if True, notify the client once at the end.
    """
    for ind, specid in enumerate(specid_list):
        spectrum_list, redshift_distributions = input.load_data_from_specid(specid, trace_name_list[ind],
                                                                            catalog_name)
        if wavelength_unit is None and len(spectrum_list) > 0:
            wavelength_unit = spectrum_list[0].wavelength_unit
        if flux_unit is None and len(spectrum_list) > 0:
            flux_unit = spectrum_list[0].flux_unit
        added_traces = []
        for spectrum in spectrum_list:
            trace = self._get_rescaled_axis_in_trace(spectrum.to_dict(), to_wavelength_unit=wavelength_unit,
                                                     to_flux_unit=flux_unit)
            axis_units = data_dict['axis_units']
            # Record the axis units when either of them is still missing.
            if axis_units.get('wavelength_unit') is None or axis_units.get('flux_unit') is None:
                self._set_axis_units(data_dict, wavelength_unit, flux_unit)
            added_traces.append(trace)
        self._set_colors_for_new_traces(new_traces=added_traces,
                                        current_trace_colors=self._get_current_colors(data_dict))
        added_zdists = []
        if redshift_distributions is not None and len(redshift_distributions) > 0:
            # Only one redshift distribution per spec id for now; it takes the colour of its source trace.
            redshift_distribution = redshift_distributions[0]
            for trace in added_traces:
                if trace.get("name") == redshift_distribution.ancestors[0]:
                    redshift_distribution.color = trace['color']
                    added_zdists.append(redshift_distribution.to_dict())
        self._add_trace_to_data(data_dict, added_traces, do_update_client=False)
        self._add_zdist_to_data(data_dict, added_zdists, do_update_client=False)
    if do_update_client:
        self.update_client()
def _load_from_file(self, trace_name, catalog_name, decoded_bytes=None, file_path=None,
                    wavelength_unit=WavelengthUnit.ANGSTROM, flux_unit=FluxUnit.F_lambda):
    """
    Load spectra from raw bytes or a file path and return them as trace dicts in the requested units.

    :param trace_name: name given to the loaded trace(s).
    :param catalog_name: catalog passed on to the loader.
    :param decoded_bytes: file contents, when loading an upload.
    :param file_path: path of the file, when loading from disk.
    :param wavelength_unit: target wavelength unit; if None, the unit of the first loaded spectrum is used.
    :param flux_unit: target flux unit; if None, the unit of the first loaded spectrum is used.
    :return: list of rescaled trace dicts, one per loaded spectrum.
    """
    spectra = input.load_data_from_file(trace_name, catalog_name, decoded_bytes, file_path)
    if wavelength_unit is None and len(spectra) > 0:
        wavelength_unit = spectra[0].wavelength_unit
    if flux_unit is None and len(spectra) > 0:
        flux_unit = spectra[0].flux_unit
    return [self._get_rescaled_axis_in_trace(spectrum.to_dict(), to_wavelength_unit=wavelength_unit,
                                             to_flux_unit=flux_unit)
            for spectrum in spectra]
def _add_spectrum_from_file(self, file_path, data_dict, wavelength_unit, flux_unit, catalog_name, trace_name=None,
do_update_client=False):
if trace_name is None:
file_path_parts = file_path.split("/")
name_parts = file_path_parts[-1].split(".")
trace_name = ".".join(name_parts[:(len(name_parts) - 1)])
rescaled_traces = self._load_from_file(trace_name, catalog_name=catalog_name, decoded_bytes=None,
file_path=file_path, wavelength_unit=wavelength_unit,
flux_unit=flux_unit)
if wavelength_unit is None and len(rescaled_traces) > 0:
wavelength_unit = rescaled_traces[0].get('wavelength_unit')
if flux_unit is None and len(rescaled_traces) > 0:
flux_unit = rescaled_traces[0].get('flux_unit')
self._set_axis_units(data_dict, wavelength_unit, flux_unit)
self._set_colors_for_new_traces(rescaled_traces, self._get_current_colors(data_dict))
self._add_trace_to_data(data_dict, rescaled_traces, do_update_client=False)
if do_update_client:
self.update_client()
def _get_current_colors(self, application_data):
current_traces_colors = [application_data['traces'][trace_name]['color'] for trace_name in
application_data['traces']]
current_traces_colors += [application_data['redshift_distributions'][trace_name]['color'] for trace_name in
application_data['redshift_distributions']]
return current_traces_colors
def _set_color_for_new_trace(self, trace, application_data):
    """
    Set ``trace['color']`` in place to what ``get_next_color`` returns for the colours currently used in
    ``application_data``.
    """
    colors_in_use = self._get_current_colors(application_data)
    trace['color'] = get_next_color(colors_in_use)
def _set_colors_for_new_traces(self, new_traces, current_trace_colors):
new_colors = self._get_colors_for_new_traces(current_trace_colors, num_output_colors=len(new_traces))
for i in range(len(new_traces)):
new_traces[i]['color'] = new_colors[i]
def _get_colors_for_new_traces(self, current_trace_colors, num_output_colors=1):
    """
    Pick ``num_output_colors`` colours with ``get_next_color``.

    Each pick is made with the colours in use plus the colours picked so far.
    """
    picked = []
    while len(picked) < num_output_colors:
        picked.append(get_next_color(current_trace_colors + picked))
    return picked
def _synch_data(self, base_data_dict, incomplete_data_dict, do_update_client=False):
    """
    Copy every ``Viewer.APP_DATA_KEYS`` entry from ``base_data_dict`` into ``incomplete_data_dict``.

    :param do_update_client: if True, notify the client afterwards.
    """
    keys_to_copy = Viewer.APP_DATA_KEYS
    for app_key in keys_to_copy:
        incomplete_data_dict[app_key] = base_data_dict[app_key]
    if do_update_client:
        self.update_client()
def _add_trace_to_data(self, application_data, trace, do_update_client=False):
if type(trace) != list:
_traces = [trace]
else:
_traces = trace
traces = application_data['traces']
for trace in _traces:
if trace.get('name') in traces:
raise Exception("Trace named '" + trace.get('name') + "' already exists.")
for trace in _traces:
traces[trace.get('name')] = trace
application_data['traces'] = traces
self._set_trace_updates_info(application_data, added_trace_names=[trace.get('name') for trace in _traces])
if do_update_client:
self.update_client()
def _add_zdist_to_data(self, application_data, redshift_distribution, do_update_client=True):
if type(redshift_distribution) != list:
_zdists = [redshift_distribution]
else:
_zdists = redshift_distribution
zdists = application_data['redshift_distributions']
for zdist in _zdists:
if zdist.get('name') in zdists:
raise Exception("Redshift distribution named '" + zdist.get('name') + "' already exists.")
for zdist in _zdists:
zdists[zdist.get('name')] = zdist
application_data['redshift_distributions'] = zdists
self._set_trace_updates_info(application_data, added_zdist_names=[zdist.get('name') for zdist in _zdists])
if do_update_client:
self.update_client()
def _remove_traces(self, trace_names, data_dict, do_update_client=True, also_remove_children=False):
# add derived traces to be removed: iterate over traces and find the ones whose ancestors are in the 'trace_names'
_traces_to_remove = [name for name in trace_names]
for name in data_dict['traces']:
for ancestor in data_dict['traces'][name]['ancestors']:
if ancestor in _traces_to_remove:
if also_remove_children:
# remove all children.
_traces_to_remove.append(name)
else:
# remove only if it is not visible
if data_dict['traces'][name]['is_visible'] == False:
_traces_to_remove.append(name)
# remove duplicates
_traces_to_remove = set(_traces_to_remove)
traces = data_dict['traces']
for trace_name in _traces_to_remove:
traces.pop(trace_name)
data_dict['traces'] = traces
self._set_trace_updates_info(data_dict, removed_trace_names=[n for n in _traces_to_remove])
# delete redshift distributions associated with traces to remove
_zdists_to_remove = []
for zdist_name in data_dict['redshift_distributions']:
for ancestor in data_dict['redshift_distributions'][zdist_name]['ancestors']:
if ancestor in _traces_to_remove:
_zdists_to_remove.append(zdist_name)
_zdists_to_remove = set(_zdists_to_remove)
zdists = data_dict['redshift_distributions']
for zdist_name in _zdists_to_remove:
zdists.pop(zdist_name)
data_dict['redshift_distributions'] = zdists
self._set_trace_updates_info(data_dict, removed_zdist_names=[n for n in _zdists_to_remove])
# remove fitting models associated with traces to remove
_fitted_models_to_remove = []
for fitted_model_name in data_dict['fitted_models']:
for ancestor in data_dict['fitted_models'][fitted_model_name]['ancestors']:
if ancestor in _traces_to_remove:
_fitted_models_to_remove.append(fitted_model_name)
_fitted_models_to_remove = set(_fitted_models_to_remove)
fitted_models = data_dict['fitted_models']
for fitted_model_name in _fitted_models_to_remove:
fitted_models.pop(fitted_model_name)
data_dict['fitted_models'] = fitted_models
if do_update_client:
self.update_client()
def _toggle_derived_traces(self, derived_trace_type, ancestor_trace_names, data_dict, do_update_client=False):
ancestor_trace_names = np.asarray(ancestor_trace_names)
traces = data_dict['traces']
toggled_trace_names = []
for derived_trace_name in traces:
trace = traces[derived_trace_name]
has_selected_ancestors = np.any(np.in1d(ancestor_trace_names, np.asarray(trace.get('ancestors'))))
if has_selected_ancestors:
if trace["spectrum_type"] == derived_trace_type:
# consider only the first derived trace of a particular type
if trace.get("inner_type_rank") == 1:
# self.write_info("Toggle " + derived_trace_name + " from is_visible=" + str(data_dict['traces'][derived_trace_name]["is_visible"]))
trace["is_visible"] = False if trace["is_visible"] == True else True
# self.write_info("Toggle " + derived_trace_name + " intermediate to is_visible=" + str(trace["is_visible"]))
traces[derived_trace_name] = trace
toggled_trace_names.append(derived_trace_name)
# self.write_info("Toggle " + derived_trace_name + " to is_visible=" + str(data_dict['traces'][derived_trace_name]["is_visible"]))
data_dict["traces"] = traces
self._set_trace_updates_info(data_dict, updated_trace_names=toggled_trace_names)
if do_update_client:
self.update_client()
def _include_derived_traces(self, spectrum_types, ancestor_trace_names, data_dict, do_update_client=False):
ancestor_trace_names = np.asarray(ancestor_trace_names)
for derived_trace_name in data_dict['traces']:
trace = data_dict['traces'][derived_trace_name]
has_selected_ancestors = np.any(np.in1d(ancestor_trace_names, np.asarray(trace.get('ancestors'))))
if has_selected_ancestors:
if trace["spectrum_type"] in spectrum_types:
trace["is_visible"] = True
else:
trace["is_visible"] = False
data_dict[derived_trace_name] = trace
if do_update_client:
self.update_client()
def write_info(self, info, file_endding=''):
    """
    Append ``"<datetime.now()> <info>"`` as one line to the info log file, when logging is enabled.

    The file is ``base_logs_directory + 'info' + ('_' + file_endding if given) + '.txt'``; nothing is written
    when ``do_log`` is falsy.

    :param info: message to log (must be a string).
    :param file_endding: optional suffix distinguishing the log file.
    """
    if not do_log:
        return
    suffix = '_' + file_endding if file_endding != '' else ''
    log_path = base_logs_directory + 'info' + suffix + '.txt'
    with open(log_path, 'a+') as log_file:
        log_file.write(str(datetime.now()) + " " + info + "\n")
def __set_app_data_timestamp(self, timestamp=None):
if timestamp is not None:
self.app_data_timestamp['timestamp'] = timestamp # in sceconds
else:
self.app_data_timestamp['timestamp'] = datetime.timestamp(datetime.now()) # in sceconds
self.write_info("Updated timestamp to " + str(self.app_data_timestamp['timestamp']))
def update_client(self, component_names=None, timestamp=None):
    """
    Ask the connected client to refresh by sending a JSON websocket message.

    The message is ``{"component_names": [...], "timestamp": ...}``.

    :param component_names: names of the components to refresh; an empty list when omitted.
    :param timestamp: optional value forwarded as-is.
    """
    # None default instead of a shared mutable list default.
    if component_names is None:
        component_names = []
    message = {'component_names': component_names, 'timestamp': timestamp}
    self._send_websocket_message(json.dumps(message))
def _send_websocket_message(self, message):
self.socketio.emit("update", message)
def get_data_dict(self, data):
    """
    Return ``data`` unchanged, or a fresh ``build_app_data()`` dict when ``data`` is None.
    """
    if data is None:
        return self.build_app_data()
    return data
def _unsmooth_trace(self, trace_names, application_data, do_update_client=True):
    """
    Restore the un-smoothed flux of the given traces.

    The flux is rebuilt from the trace's stored 'flambda' (F_lambda), converted to the trace's current flux
    and wavelength units.

    :param do_update_client: if True (default), notify the client afterwards.
    """
    for name in trace_names:
        all_traces = application_data['traces']
        target = all_traces[name]
        target['flux'] = fl.convert_flux(flux=target['flambda'], wavelength=target['wavelength'],
                                         from_flux_unit=FluxUnit.F_lambda,
                                         to_flux_unit=target.get('flux_unit'),
                                         to_wavelength_unit=target.get('wavelength_unit'))
        all_traces[name] = target
        application_data['traces'] = all_traces
    self._set_trace_updates_info(application_data, updated_trace_names=list(trace_names))
    if do_update_client:
        self.update_client()
def _get_smoother(self, smoothing_kernel, kernel_width):
    """
    Return a smoother for ``smoothing_kernel``.

    For one of the default kernels, a new ``Smoother`` configured with ``int(kernel_width)`` is returned;
    otherwise the user-defined smoother stored on the viewer (``self.smoother``) is returned.
    """
    if smoothing_kernel not in default_smoothing_kernels:
        return self.smoother
    configured = Smoother()
    configured.set_smoothing_kernel(kernel=smoothing_kernel, kernel_width=int(kernel_width))
    return configured
def _smooth_trace(self, trace_names, application_data, smoother, do_update_client=True, do_substract=False,
                  as_new_trace=False, new_trace_name=None):
    """
    Smooth traces in place, or add smoothed copies of them as new SMOOTHED traces.

    The flux that gets smoothed is always rebuilt from the trace's stored 'flambda' (F_lambda) converted to the
    trace's current units, so smoothing a trace in place again does not compound.

    :param trace_names: names of the traces to smooth; names not present are ignored.
    :param application_data: application data (plain dict or manager dict proxy).
    :param smoother: object with ``get_smoothed_flux(flux)`` and a ``kernel_func_type`` attribute.
    :param do_update_client: if True (default), notify the client afterwards.
    :param do_substract: if True, keep ``flux - smoothed flux`` instead of the smoothed flux.
    :param as_new_trace: if True, add new traces instead of overwriting the flux of the given ones.
    :param new_trace_name: name for the new trace; defaults to ``smoothed_<n>_<trace name>`` where ``n`` is one
        more than the number of SMOOTHED traces already derived from that trace.
    """
    # Read the traces once and write them back once. With a manager dict proxy every read returns a fresh
    # copy, so re-reading per trace would drop the changes made for the previous traces (and left `traces`
    # unbound when no name matched).
    traces = application_data['traces']
    # Colours in use, extended as new traces are coloured, so each pick accounts for the new ones too.
    used_colors = self._get_current_colors(application_data)
    added_trace_names = []
    for trace_name in trace_names:
        if trace_name not in traces:
            continue
        trace = traces[trace_name]
        flux = fl.convert_flux(flux=trace['flambda'], wavelength=trace['wavelength'],
                               from_flux_unit=FluxUnit.F_lambda, to_flux_unit=trace.get('flux_unit'),
                               to_wavelength_unit=trace.get('wavelength_unit'))
        smoothed_flux = smoother.get_smoothed_flux(flux)
        if do_substract:
            smoothed_flux = flux - smoothed_flux
        if not as_new_trace:
            trace['flux'] = smoothed_flux
            traces[trace_name] = trace
            continue
        if new_trace_name is None:
            n_smoothed = len([name for name in traces
                              if trace_name in traces[name]["ancestors"]
                              and SpectrumType.SMOOTHED in traces[name]["spectrum_type"]])
            smoothed_trace_name = "smoothed_" + str(n_smoothed + 1) + "_" + trace_name
        else:
            smoothed_trace_name = new_trace_name
        ancestors = trace['ancestors'] + [trace_name]
        f_lambda = fl.convert_flux(flux=[y for y in smoothed_flux],
                                   wavelength=[x for x in trace["wavelength"]],
                                   from_flux_unit=trace['flux_unit'], to_flux_unit=FluxUnit.F_lambda,
                                   to_wavelength_unit=WavelengthUnit.ANGSTROM)
        smoothed_trace = Trace(name=smoothed_trace_name, wavelength=[x for x in trace["wavelength"]],
                               flux=[y for y in smoothed_flux], flux_error=trace.get('flux_error'),
                               ancestors=ancestors, spectrum_type=SpectrumType.SMOOTHED, color="black",
                               linewidth=1, alpha=1.0, wavelength_unit=trace['wavelength_unit'],
                               flux_unit=trace['flux_unit'], flambda=f_lambda,
                               flambda_error=trace.get("flambda_error"), catalog=trace['catalog']).to_dict()
        smoothed_trace['color'] = get_next_color(used_colors)
        used_colors.append(smoothed_trace['color'])
        traces[smoothed_trace_name] = smoothed_trace
        added_trace_names.append(smoothed_trace_name)
    application_data['traces'] = traces

    # If the kernel is custom, add it to the list of known kernel types.
    current_smoothing_kernels = application_data['smoothing_kernel_types']
    self.write_info("current_smoothing_kernels1: " + str(current_smoothing_kernels))
    if smoother.kernel_func_type not in current_smoothing_kernels:
        current_smoothing_kernels.append(smoother.kernel_func_type)
    application_data['smoothing_kernel_types'] = current_smoothing_kernels

    if as_new_trace:
        self._set_trace_updates_info(application_data, added_trace_names=added_trace_names)
    else:
        self._set_trace_updates_info(application_data,
                                     updated_trace_names=[n for n in trace_names if n in traces])
    if do_update_client:
        self.update_client()
def _rescale_axis(self, application_data, to_wavelength_unit=WavelengthUnit.ANGSTROM,
                  to_flux_unit=FluxUnit.F_lambda, do_update_client=False):
    """
    Convert every trace to the given units and record them as the axis units.

    All traces are reported as updated.

    :param do_update_client: if True, notify the client afterwards.
    """
    traces = application_data['traces']
    all_names = list(traces)
    for name in all_names:
        traces[name] = self._get_rescaled_axis_in_trace(traces[name], to_wavelength_unit=to_wavelength_unit,
                                                        to_flux_unit=to_flux_unit)
    application_data['traces'] = traces
    self._set_axis_units(application_data, to_wavelength_unit, to_flux_unit)
    self._set_trace_updates_info(application_data, updated_trace_names=all_names)
    if do_update_client:
        self.update_client()
def _get_rescaled_axis_in_trace(self, trace, to_wavelength_unit=WavelengthUnit.ANGSTROM,
                                to_flux_unit=FluxUnit.F_lambda):
    """
    Convert a trace's wavelength axis and flux (with errors) to the given units; returns the same dict.

    The trace is modified in place. The wavelength is converted first, and the flux conversion uses the
    converted wavelengths. When leaving AB magnitudes and the trace has non-empty 'flambda' values, the flux
    and errors are rebuilt from 'flambda' / 'flambda_error' (F_lambda). Otherwise they are converted from
    'flux' / 'flux_error' in the trace's current flux unit.

    Documentation:
    https://synphot.readthedocs.io/en/latest/synphot/units.html
    https://synphot.readthedocs.io/en/latest/api/synphot.units.convert_flux.html#synphot.units.convert_flux
    """
    trace['wavelength'] = fl.convert_wavelength(wavelength=trace['wavelength'],
                                                from_wavelength_unit=trace['wavelength_unit'],
                                                to_wavelength_unit=to_wavelength_unit)
    trace['wavelength_unit'] = to_wavelength_unit

    leaving_magnitudes = (trace.get('flux_unit') == FluxUnit.AB_magnitude
                          and to_flux_unit != FluxUnit.AB_magnitude)
    flambda = trace.get("flambda")
    if leaving_magnitudes and flambda is not None and len(flambda) > 0:
        source_flux, source_error, source_unit = flambda, trace.get("flambda_error"), FluxUnit.F_lambda
    else:
        source_flux, source_error, source_unit = trace['flux'], trace['flux_error'], trace.get('flux_unit')

    converted_wavelength = trace['wavelength']
    trace['flux'] = fl.convert_flux(flux=source_flux, wavelength=converted_wavelength,
                                    from_flux_unit=source_unit, to_flux_unit=to_flux_unit,
                                    to_wavelength_unit=to_wavelength_unit)
    trace['flux_error'] = fl.convert_flux(flux=source_error, wavelength=converted_wavelength,
                                          from_flux_unit=source_unit, to_flux_unit=to_flux_unit,
                                          to_wavelength_unit=to_wavelength_unit)
    trace['flux_unit'] = to_flux_unit
    return trace
def _bin_wavelength_axis(self, trace_names, application_data, bin_size, wavelength_unit, flux_unit,
                         do_update_client=False):
    """
    Add a binned copy of each given trace.

    Consecutive groups of ``bin_size`` points are averaged (wavelength, flux and, when the trace has flux
    errors, flux error); a trailing group with fewer than ``bin_size`` points is dropped. The binned flux and
    errors are also stored converted to F_lambda in 'flambda' / 'flambda_error'. Each new trace is named
    ``binned_<i>_<name>`` with the lowest unused ``i`` below 100 (falling back to a uuid1-based name), lists
    the source trace among its ancestors, and gets a new colour.

    :param trace_names: names of the traces to bin.
    :param application_data: application data (plain dict or manager dict proxy).
    :param bin_size: number of points per bin (int, greater than 0).
    :param wavelength_unit: not used by this method.
    :param flux_unit: unit of the traces' flux, used to convert the binned values to F_lambda.
    :param do_update_client: if True, notify the client afterwards.
    :raises Exception: if ``bin_size`` is not greater than 0.
    """
    if bin_size <= 0:
        raise Exception("bin_size should be greater than 0.")
    taken_names = set(application_data['traces'])
    added_traces = []
    for trace_name in trace_names:
        # A shallow copy is enough: every field changed below is replaced, not mutated.
        trace = application_data['traces'][trace_name].copy()
        wavelength = trace['wavelength']
        flux = trace['flux']
        flux_error = trace.get('flux_error')
        has_errors = flux_error is not None and len(flux_error) > 0
        wave_binned, flux_binned, flux_err_binned = [], [], []
        # Start indices of the complete bins only.
        for start in range(0, len(flux) - bin_size + 1, bin_size):
            stop = start + bin_size
            wave_binned.append(np.mean(wavelength[start:stop]))
            flux_binned.append(np.mean(flux[start:stop]))
            if has_errors:
                flux_err_binned.append(np.mean(flux_error[start:stop]))
        flambda = fl.convert_flux(flux=flux_binned, wavelength=wave_binned, from_flux_unit=flux_unit,
                                  to_flux_unit=FluxUnit.F_lambda, to_wavelength_unit=WavelengthUnit.ANGSTROM)
        flambda_err = fl.convert_flux(flux=flux_err_binned, wavelength=wave_binned, from_flux_unit=flux_unit,
                                      to_flux_unit=FluxUnit.F_lambda, to_wavelength_unit=WavelengthUnit.ANGSTROM)
        trace['wavelength'] = wave_binned
        trace['flux'] = flux_binned
        trace['flux_error'] = flux_err_binned
        trace['flambda'] = [value for value in flambda]
        # 'flambda_error' is the key read elsewhere (smoothing, unit rescaling); writing 'flambda_err'
        # left the unbinned errors in place with a mismatched length.
        trace['flambda_error'] = [value for value in flambda_err]
        trace['ancestors'] = trace['ancestors'] + [trace_name]
        new_name = None
        for i in range(100):
            candidate = "binned_" + str(i) + "_" + trace['name']
            if candidate not in taken_names:
                new_name = candidate
                break
        if new_name is None:
            new_name = "binned_" + str(uuid.uuid1()) + "_" + trace_name
        trace['name'] = new_name
        taken_names.add(new_name)
        added_traces.append(trace)
    self._set_colors_for_new_traces(added_traces, self._get_current_colors(application_data))
    self._add_trace_to_data(application_data, added_traces)
    if do_update_client:
        self.update_client()
def _get_model_fitter(self, trace_name, application_data, fitting_model, selected_data):
    """
    Build the ``ModelFitter`` to use for fitting ``trace_name``.

    For a default fitting model, the model and fitter come from ``ModelFitter.get_model_with_fitter`` given
    the x/y values of the selected points on the trace's curve. Otherwise the user-defined model and fitter
    stored on the viewer are used, with type ``FittingModels.CUSTOM``.
    """
    curve_number = self._get_curve_mapping(application_data)[trace_name]
    curve_points = [point for point in selected_data["points"] if point['curveNumber'] == curve_number]
    x = np.asarray([point['x'] for point in curve_points])
    y = np.asarray([point['y'] for point in curve_points])
    if fitting_model not in default_fitting_models:
        custom_fitter = self.model_fitter.fitter
        custom_model = self.model_fitter.model
        return ModelFitter(custom_model, custom_fitter, FittingModels.CUSTOM)
    model, fitter = ModelFitter.get_model_with_fitter(fitting_model, x, y)
    return ModelFitter(model, fitter, fitting_model)
def _fit_model_to_flux(self, trace_names, application_data, model_fitters, selected_data, median_filter_width=1,
                       do_update_client=False, add_fit_substracted_trace=False):
    """
    Fit models to the selected points of traces and add the fitted curves as new FIT traces.

    For every model fitter and every trace, the points of ``selected_data`` on the trace's curve are fitted,
    weighted by ``1 / flux_error`` when the trace has flux errors. The fitted model is evaluated on the
    trace's own wavelengths between the smallest and largest selected wavelength and added as a trace named
    ``fit<n>_<trace>``. With ``add_fit_substracted_trace``, a second trace ``fitsub_<n>_<trace>`` holding
    ``flux - model`` on the same wavelengths is added too. ``<n>`` starts at the number of fitted models plus
    one and is increased until the names are unused.

    :param trace_names: names of the traces to fit.
    :param application_data: application data (plain dict or manager dict proxy).
    :param model_fitters: objects exposing ``model``, ``fitter`` and ``model_type``.
    :param selected_data: plotly selection with a "points" list; each point has 'x', 'y', 'curveNumber' and
        'pointIndex'.
    :param median_filter_width: currently unused.
    :param do_update_client: if True, notify the client afterwards.
    :param add_fit_substracted_trace: if True, also add the ``flux - model`` trace.
    :return: list with one info dict per fit (model type, parameter values/errors, covariance, ...).
    """
    # Documentation:
    # http://learn.astropy.org/rst-tutorials/User-Defined-Model.html
    # https://docs.astropy.org/en/stable/modeling/new-model.html
    # https://docs.astropy.org/en/stable/modeling/index.html
    # https://docs.astropy.org/en/stable/modeling/reference_api.html
    fitted_info_list = []
    fitted_traces = []
    curve_mapping = self._get_curve_mapping(application_data)
    first_fit_number = len(application_data['fitted_models']) + 1
    # Names in use, extended with the names created below, so fitting several models to the same trace in
    # one call no longer produces clashing (silently overwritten) names.
    taken_names = set(application_data['traces']) | set(application_data['fitted_models'])
    for model_fitter in model_fitters:
        fitter = model_fitter.fitter
        model = model_fitter.model
        fitting_model_type = model_fitter.model_type
        for trace_name in trace_names:
            trace = application_data['traces'].get(trace_name)
            curve_number = curve_mapping[trace_name]
            points = [point for point in selected_data["points"] if point['curveNumber'] == curve_number]
            x = np.asarray([point['x'] for point in points])
            y = np.asarray([point['y'] for point in points])
            ind = [point['pointIndex'] for point in points]
            flux_error = trace.get("flux_error")
            # Weight by the errors only when there are some (the previous `or` test crashed on None/empty
            # errors, and 1/None crashed the fit).
            if flux_error is not None and len(flux_error) > 0:
                y_err = np.asarray(flux_error)[ind]
                weights = 1. / y_err
            else:
                y_err = None
                weights = None
            min_x, max_x = np.min(x), np.max(x)
            wavelength = np.asarray(trace["wavelength"])
            in_range = (wavelength >= min_x) & (wavelength <= max_x)
            wave = wavelength[in_range]
            self.write_info("x :" + str(x))
            self.write_info("y :" + str(y))
            self.write_info("err :" + str(y_err))
            fitted_model = fitter(model, x, y, weights=weights)
            # Evaluate the model on the trace's own wavelength grid within the selected range.
            y_grid = fitted_model(wave)
            parameters_covariance_matrix = fitter.fit_info.get('param_cov')
            if parameters_covariance_matrix is not None:
                parameter_errors = np.sqrt(np.diag(parameters_covariance_matrix))
            else:
                parameter_errors = [None for _ in fitted_model.parameters]

            fit_number = first_fit_number
            while ("fit" + str(fit_number) + "_" + trace_name) in taken_names or (
                    add_fit_substracted_trace
                    and ("fitsub_" + str(fit_number) + "_" + trace_name) in taken_names):
                fit_number += 1
            ancestors = trace['ancestors'] + [trace_name]
            fitted_trace_name = "fit" + str(fit_number) + "_" + trace_name
            taken_names.add(fitted_trace_name)
            # 'flambda' must follow the fitted curve's wavelength grid, because smoothing/unsmoothing rebuild the
            # flux from it. So convert the model flux instead of taking the selected points' flambda.
            fit_flambda = fl.convert_flux(flux=y_grid, wavelength=wave,
                                          from_flux_unit=trace['flux_unit'], to_flux_unit=FluxUnit.F_lambda,
                                          to_wavelength_unit=WavelengthUnit.ANGSTROM)
            fitted_traces.append(Trace(name=fitted_trace_name, wavelength=[w for w in wave],
                                       flux=[f for f in y_grid], ancestors=ancestors,
                                       spectrum_type=SpectrumType.FIT, color="black", linewidth=1, alpha=1.0,
                                       wavelength_unit=trace['wavelength_unit'], flux_unit=trace['flux_unit'],
                                       flambda=[f for f in fit_flambda], catalog=trace['catalog']).to_dict())
            if add_fit_substracted_trace:
                fitted_trace_name = "fitsub_" + str(fit_number) + "_" + trace_name
                taken_names.add(fitted_trace_name)
                residual_flux = np.asarray(trace["flux"])[in_range] - y_grid
                residual_flambda = fl.convert_flux(flux=residual_flux, wavelength=wave,
                                                   from_flux_unit=trace['flux_unit'],
                                                   to_flux_unit=FluxUnit.F_lambda,
                                                   to_wavelength_unit=WavelengthUnit.ANGSTROM)
                fitted_traces.append(Trace(name=fitted_trace_name, wavelength=[w for w in wave],
                                           flux=[f for f in residual_flux], ancestors=ancestors,
                                           spectrum_type=SpectrumType.FIT, color="black", linewidth=1,
                                           alpha=1.0, wavelength_unit=trace['wavelength_unit'],
                                           flux_unit=trace['flux_unit'], flambda=residual_flambda,
                                           catalog=trace['catalog']).to_dict())
            # NOTE(review): when the residual trace is added, the info is keyed by the fitsub_ trace rather
            # than the fit trace; this is kept as it was, so confirm which one is intended.
            fitted_info = {
                'trace_name': fitted_trace_name,
                'original_trace': trace_name,
                'ancestors': ancestors + [fitted_trace_name],
                'model': fitting_model_type,
                'parameter_names': [name for name in fitted_model.param_names],
                'parameter_values': {name: value for (name, value) in
                                     zip(fitted_model.param_names, fitted_model.parameters)},
                'covariance': parameters_covariance_matrix,
                'parameter_errors': {name: error for (name, error) in
                                     zip(fitted_model.param_names, parameter_errors)},
                'selection_indexes': ind,
                'wavelength_unit': trace['wavelength_unit'],
                'flux_unit': trace['flux_unit'],
            }
            # Appended exactly once per fit (it used to be appended a second time after the loop).
            fitted_info_list.append(fitted_info)
            current_fitting_model_types = application_data['fitting_model_types']
            if fitting_model_type not in current_fitting_model_types:
                current_fitting_model_types.append(fitting_model_type)
            application_data['fitting_model_types'] = current_fitting_model_types
    self._set_colors_for_new_traces(fitted_traces, self._get_current_colors(application_data))
    self._add_trace_to_data(application_data, fitted_traces, do_update_client=False)
    fitted_models = application_data['fitted_models']
    for fitted_model_info in fitted_info_list:
        fitted_models[fitted_model_info.get('trace_name')] = fitted_model_info
    application_data['fitted_models'] = fitted_models
    if do_update_client:
        self.update_client()
    return fitted_info_list
def _get_selection(self, selected_data, data_dict):
    """
    Builds the viewer selection from a selection payload, tagging each point with its trace name.

    :param selected_data: dict or None
        Selection payload; when not empty it must contain a 'points' list whose items have a 'curveNumber'.
    :param data_dict: dict
        Application data, used to compute the curve mapping (trace name -> curve number).
    :return: dict
        {} when selected_data is None or empty; otherwise a shallow copy of selected_data whose points
        carry an extra 'trace_name' key.
    """
    if selected_data is None or selected_data == {}:
        return {}
    name_by_curve = {curve: name for name, curve in self._get_curve_mapping(data_dict).items()}
    selection = dict(selected_data)
    tagged_points = []
    for pt in selection['points']:
        # the point dicts are shared with selected_data (shallow copy), so the caller's points get tagged too
        pt['trace_name'] = name_by_curve[pt['curveNumber']]
        tagged_points.append(pt)
    selection['points'] = tagged_points
    return selection
def _set_selection(self, trace_name, application_data, selection_indices=None, do_update_client=False):
    """
    Replaces the selected points of one trace in the application selection.

    Points belonging to other curves are kept. The previous points of this trace's curve are dropped, and
    one point is added per index in selection_indices. Each new point carries the trace's wavelength ('x')
    and flux ('y') at that index.

    :param trace_name: str
        Name of the trace; it must be in the curve mapping (i.e. be visible), otherwise KeyError is raised.
    :param application_data: dict
        Application data holding 'traces' and 'selection'.
    :param selection_indices: list of int or None
        Indices into the trace's wavelength/flux arrays. None (the default) selects no point.
    :param do_update_client: bool
        If True, the client is updated at the end.
    :return: None
    """
    # None instead of a mutable [] default, so no list object is shared between calls
    if selection_indices is None:
        selection_indices = []
    curve_number = self._get_curve_mapping(application_data)[trace_name]
    # look the trace up once instead of once per index
    trace = application_data['traces'][trace_name]
    wavelength = trace["wavelength"]
    flux = trace["flux"]
    selection = application_data["selection"]
    # a selection without a 'points' entry yet is treated as having no points
    new_points = [point for point in selection.get("points", []) if point['curveNumber'] != curve_number]
    for ind in selection_indices:
        new_points.append({
            'curveNumber': curve_number,
            'pointNumber': ind,
            'pointIndex': ind,
            'trace_name': trace_name,
            'x': wavelength[ind],
            'y': flux[ind],
        })
    selection['points'] = new_points
    # write the selection back explicitly rather than relying on in-place mutation
    application_data["selection"] = selection
    if do_update_client:
        self.update_client()
def _update_trace_properties(self, data_dict, properties_list, do_update_client=True, also_remove_children=False):
    """
    Applies trace property changes from a list of property dicts, then removes the traces not listed.

    :param data_dict: dict
        Application data; its 'traces' entry is rebuilt (re-keyed by trace name) by this method.
    :param properties_list: list of dict
        One dict per trace to keep. 'rank' is the position of the trace in data_dict['traces'] (in
        iteration order), and 'name' is its (possibly new) name. Any other key overwrites the trace's value
        only when the trace already has that key; 'ancestors' is never overwritten.
    :param do_update_client: bool
        If True, the client is updated at the end.
    :param also_remove_children: bool
        Passed to '_remove_traces' for the traces that are removed.
    :return: None
    """
    # traces in stored order; 'rank' indexes into this list
    trace_list = [data_dict['traces'][trace] for trace in data_dict['traces']]
    for properties in properties_list:
        rank = int(properties['rank'])
        trace = trace_list[rank]
        old_trace_name = trace['name']
        new_trace_name = properties['name']
        # copy only keys the trace already has; 'ancestors' is managed by this method, not by the caller
        # NOTE(review): a None 'name' is still copied into trace['name'] here, since 'name' is a trace key
        for property in properties:
            if property in trace and property != 'ancestors':
                trace[property] = properties[property]
        trace_list[rank] = trace
        if new_trace_name is not None and new_trace_name != old_trace_name:
            # change name in ancestors of other traces:
            trace_list2 = []
            for i, trace2 in enumerate(trace_list):
                ancestors = trace2['ancestors']
                if old_trace_name in ancestors:
                    ancestors = [new_trace_name if ancestor == old_trace_name else ancestor for ancestor in
                                 ancestors]
                trace2['ancestors'] = ancestors
                trace_list2.append(trace2)
            trace_list = [t for t in trace_list2]
    # re-key the traces by their current names, so renames take effect here
    data_dict['traces'] = {trace['name']: trace for trace in trace_list}
    # every trace whose name is not in properties_list is removed
    traces_in_properties = [prop['name'] for prop in properties_list]
    traces_to_delete = [trace_name for trace_name in data_dict['traces'] if trace_name not in traces_in_properties]
    self._remove_traces(traces_to_delete, data_dict, do_update_client=False,
                        also_remove_children=also_remove_children)
    # record which traces were removed and which were updated
    self._set_trace_updates_info(data_dict, removed_trace_names=traces_to_delete,
                                 updated_trace_names=traces_in_properties)
    if do_update_client:
        self.update_client()
def _get_line_analysis(self, trace_names, application_data, selected_data, continuum_trace, median_window=10,
                       as_new_trace=False, do_update_client=False):
    """
    Measures a user-selected spectral line on each given trace, relative to a continuum trace.

    For every trace, the selected points are matched (by exact wavelength) to continuum points. The
    continuum-normalized spectrum gives the centroid and equivalent width. The continuum-subtracted
    spectrum gives the line flux and Gaussian sigma width. The results are stored in a
    "Line_<trace_name>" trace that holds the selected points, and that trace is added to the data.

    :param trace_names: list of str
        Names of the traces to analyse; each must be in the curve mapping (visible).
    :param application_data: dict
        Application data holding 'traces'; the new line traces are added to it.
    :param selected_data: dict
        Selection with a 'points' list; points are assigned to traces through 'curveNumber'.
    :param continuum_trace: str
        Name of the trace used as continuum.
    :param median_window: int
        Currently unused.
    :param as_new_trace: bool
        Currently unused; line traces are always added.
    :param do_update_client: bool
        If True, the client is updated at the end.
    :return: None
    """
    # Documentation:
    # http://learn.astropy.org/rst-tutorials/User-Defined-Model.html
    # https://docs.astropy.org/en/stable/modeling/new-model.html
    # https://docs.astropy.org/en/stable/modeling/index.html
    # https://docs.astropy.org/en/stable/modeling/reference_api.html

    def _values_at(values, indexes):
        # values[indexes] as a list, or None when the array is missing (None) or empty
        if values is None or len(values) == 0:
            return None
        return list(np.asarray(values)[indexes])

    continuum = application_data['traces'][continuum_trace]
    continuum_x = np.asarray(continuum['wavelength'])
    continuum_y = np.asarray(continuum['flux'])
    # new traces are only added after the loop, so the mapping is the same for every iteration
    curve_mapping = self._get_curve_mapping(application_data)
    added_traces = []
    for trace_name in trace_names:
        trace = application_data['traces'].get(trace_name)
        curve_number = curve_mapping[trace_name]
        trace_points = [point for point in selected_data["points"] if point['curveNumber'] == curve_number]
        x = np.asarray([point['x'] for point in trace_points])
        y = np.asarray([point['y'] for point in trace_points])
        ind = [point['pointIndex'] for point in trace_points]
        # error arrays may be None or empty: index them only when they actually hold values
        y_err = _values_at(trace.get("flux_error"), ind)
        flambda_error = _values_at(trace.get("flambda_error"), ind)
        spectrum = Spectrum1D(spectral_axis=x * WavelengthUnit.get_astropy_unit(trace["wavelength_unit"]),
                              flux=y * FluxUnit.get_astropy_unit(trace["flux_unit"]))
        # continuum points at exactly the selected wavelengths
        ind2 = []
        for _x in x:
            i = np.where(continuum_x == _x)
            if len(i[0]) > 0:
                ind2.append(i[0][0])
        continuum_spec = Spectrum1D(
            spectral_axis=continuum_x[ind2] * WavelengthUnit.get_astropy_unit(continuum["wavelength_unit"]),
            flux=continuum_y[ind2] * FluxUnit.get_astropy_unit(continuum["flux_unit"]))
        norm_spectrum = spectrum / continuum_spec
        diff_spectrum = spectrum - continuum_spec
        specline = SpectralLine()
        specline.line = "user-defined"
        specline.wavelength = analysis.centroid(spectrum=norm_spectrum, region=None).value
        specline.ew = analysis.equivalent_width(spectrum=norm_spectrum, regions=None).value
        specline.area = analysis.line_flux(diff_spectrum, regions=None).value
        specline.sigma = analysis.gaussian_sigma_width(diff_spectrum).value
        specline.cont_level = np.mean(continuum_spec.flux).value
        specline.wavelength_unit = trace['wavelength_unit']
        specline.flux_unit = trace['flux_unit']
        line_trace = Trace()
        line_trace.name = "Line_" + trace_name
        line_trace.catalog = trace['catalog']
        line_trace.wavelength = list(x)
        line_trace.flux = list(y)
        if y_err is not None:
            line_trace.flux_error = y_err
        line_trace.wavelength_unit = trace['wavelength_unit']
        line_trace.flux_unit = trace['flux_unit']
        line_trace.spectrum_type = SpectrumType.LINE
        line_trace.spectral_lines = [specline.to_dict()]
        line_trace.ancestors = trace['ancestors'] + [trace_name]
        line_trace.flambda = list(np.asarray(trace['flambda'])[ind])
        if flambda_error is not None:
            line_trace.flambda_error = flambda_error
        line_trace.linewidth = 1
        line_trace.color = "black"
        added_traces.append(line_trace.to_dict())
    self._set_colors_for_new_traces(added_traces, self._get_current_colors(application_data))
    self._add_trace_to_data(application_data, added_traces, do_update_client=False)
    if do_update_client:
        self.update_client()
def _get_curve_mappingOLD(self, application_data):
    """
    Legacy curve mapping: maps each visible trace name to its position among ALL traces.

    Hidden traces still consume a position here. The methods in this section call _get_curve_mapping,
    which numbers visible traces only.

    :param application_data: dict holding 'traces' (name -> trace dict with an 'is_visible' flag).
    :return: dict mapping trace name -> position in application_data['traces'].
    """
    traces = application_data['traces']
    return {name: position for position, name in enumerate(traces) if traces[name]['is_visible']}
def _get_curve_mapping(self, application_data):
    """
    Maps each visible trace name to its curve number (the 'curveNumber' used in selection points).

    Curve numbers are consecutive from 0, in trace order, and count only traces whose 'is_visible'
    flag is truthy.

    :param application_data: dict holding 'traces' (name -> trace dict with an 'is_visible' flag).
    :return: dict mapping visible trace name -> curve number.
    """
    traces = application_data['traces']
    visible_names = [name for name in traces if traces[name]['is_visible']]
    return {name: curve_number for curve_number, name in enumerate(visible_names)}
######################################################################################################################
######## Trace manipulation/analysis functions for execution within Jupyter ##################################################################
def set_smoothing_kernel(self, kernel=SmoothingKernels.GAUSSIAN1D, kernel_width=20, custom_array_kernel=None,
                         custom_kernel_function=None, function_array_size=21):
    """
    Configures the kernel that 'smooth_trace' will use.

    All arguments are forwarded to the smoother's 'set_smoothing_kernel'. The resulting kernel type is
    registered in the application's smoothing kernel types if it is not there yet, and the client is updated.

    :param kernel: kernel choice (default SmoothingKernels.GAUSSIAN1D).
    :param kernel_width: kernel width (default 20).
    :param custom_array_kernel: optional custom kernel given as an array.
    :param custom_kernel_function: optional custom kernel given as a function.
    :param function_array_size: size parameter forwarded to the smoother (default 21).
    :return: None
    """
    self.smoother.set_smoothing_kernel(kernel, kernel_width, custom_array_kernel, custom_kernel_function,
                                       function_array_size)
    kernel_type = self.smoother.kernel_func_type
    if kernel_type not in self.app_data["smoothing_kernel_types"]:
        # the list is written back explicitly (read-modify-write), as done elsewhere for app_data entries
        known_types = self.app_data.get("smoothing_kernel_types")
        known_types.append(kernel_type)
        self.app_data["smoothing_kernel_types"] = known_types
    self.update_client()
def smooth_trace(self, trace_name, do_substract=False):
    """
    Smooths a trace with the kernel previously configured via 'set_smoothing_kernel'.

    :param trace_name: str, name of the trace to smooth.
    :param do_substract: bool. True if the smoothed trace should be subtracted from the original trace,
        False otherwise.
    :return: None
    """
    data = self.app_data
    self._initialize_updates(data)
    self._smooth_trace([trace_name], data, self.smoother, do_update_client=True, do_substract=do_substract)
def reset_smoothing(self, trace_name):
    """
    Undoes the smoothing previously applied to a trace, via '_unsmooth_trace' with client update.

    :param trace_name: str, name of the trace.
    :return: None
    """
    data = self.app_data
    self._initialize_updates(data)
    self._unsmooth_trace([trace_name], data, do_update_client=True)
def set_custom_model_fitter(self, model, fitter):
    """
    Uses a user-provided model and fitter for subsequent calls to 'fit_model'.

    The pair is wrapped in a ModelFitter created with FittingModels.CUSTOM and stored in
    'self.model_fitter'. Its model type is registered in the application's fitting model types if
    missing, and the client is updated.

    :param model: model instance
    :param fitter: fitter instance
    :return: None
    """
    wrapped = ModelFitter(model, fitter, FittingModels.CUSTOM)
    self.model_fitter = wrapped
    model_type = wrapped.model_type
    if model_type not in self.app_data["fitting_model_types"]:
        # read-modify-write so the updated list is stored back into app_data
        known_types = self.app_data.get("fitting_model_types")
        known_types.append(model_type)
        self.app_data["fitting_model_types"] = known_types
    self.update_client()
def set_model_fitter(self, trace_name, fitting_model=FittingModels.GAUSSIAN_PLUS_LINEAR):
    """
    Builds the model fitter used by 'fit_model' from one of the predefined fitting models.

    The fitter is obtained from '_get_model_fitter', given the trace name, the application data, the
    chosen fitting model and the current data selection. It is stored in 'self.model_fitter'.

    :param trace_name: str, name of the trace, forwarded to '_get_model_fitter'.
    :param fitting_model: one of the FittingModels values (default FittingModels.GAUSSIAN_PLUS_LINEAR).
    :return: None
    """
    self.model_fitter = self._get_model_fitter(trace_name, self.app_data, fitting_model, self.app_data['selection'])
def fit_model(self, trace_name, median_filter_width=1, add_fit_substracted_trace=False):
    """
    Fits the current model fitter to the selected points of a trace.

    The fitter is the one set via 'set_model_fitter' or 'set_custom_model_fitter'; the current data
    selection is used, and the client is updated.

    Parameters
    ----------
    trace_name : str
        Name of the trace to fit.
    median_filter_width : int
        Forwarded to '_fit_model_to_flux' (default 1).
    add_fit_substracted_trace : bool
        Forwarded to '_fit_model_to_flux' (default False).
    Returns
    -------
    The first element of the list returned by '_fit_model_to_flux'.
    """
    data = self.app_data
    self._initialize_updates(data)
    results = self._fit_model_to_flux(
        [trace_name],
        data,
        [self.model_fitter],
        data['selection'],
        median_filter_width=median_filter_width,
        do_update_client=True,
        add_fit_substracted_trace=add_fit_substracted_trace,
    )
    return results[0]
def get_data_selection(self):
    """
    Returns
    -------
    The current data selection stored under 'selection' in the application data, or None if absent.
    """
    app_data = self.app_data
    return app_data.get("selection")
def set_data_selection(self, trace_name, selection_indices=None):
    """
    Sets the selected points of a trace and updates the client.

    Parameters
    ----------
    trace_name : str
        Name of the trace whose selection is replaced.
    selection_indices : list of int or None
        Indices of the trace's points to select. None (the default) clears the trace's selection.
    Returns
    -------
    None
    """
    # None instead of a mutable [] default, so no list object is shared between calls
    if selection_indices is None:
        selection_indices = []
    self._set_selection(trace_name, self.app_data, selection_indices, do_update_client=True)
def add_spectrum(self, spectrum, is_visible=True):
    """
    Adds one or several spectra to the viewer as new traces and updates the client.

    If the viewer already has both axis units set, each new trace is rescaled to them. Otherwise the
    axis units are set from the units of the trace being added.

    Parameters
    ----------
    spectrum : Spectrum or list of Spectrum
        Spectrum object(s) to add.
    is_visible : bool or list of bool
        Visibility of each spectrum; as a list it must have the same length as the spectrum list.
    Returns
    -------
    None
    Raises
    ------
    Exception
        If spectrum and is_visible have different lengths, an item is not a Spectrum, or a visibility
        flag is not a bool.
    """
    self._initialize_updates(self.app_data)
    spectra = spectrum if isinstance(spectrum, list) else [spectrum]
    visibilities = is_visible if isinstance(is_visible, list) else [is_visible]
    if len(spectra) != len(visibilities):
        raise Exception("spectrum and is_visible parameters should have the same length")
    # validate every input before adding anything
    for _spectrum, _is_visible in zip(spectra, visibilities):
        if not isinstance(_spectrum, Spectrum):
            raise Exception("Invalid type of input spectrum parameter")
        if not isinstance(_is_visible, bool):
            raise Exception("Invalid type of input is_visible parameter")
    added_traces = []
    for _spectrum, _is_visible in zip(spectra, visibilities):
        trace = Trace()
        trace.from_spectrum(_spectrum, is_visible=_is_visible)
        trace = trace.to_dict()
        # re-read every iteration: the first spectrum may have just set the axis units
        axis_units = self.app_data['axis_units']
        wavelength_unit = axis_units.get('wavelength_unit')
        flux_unit = axis_units.get('flux_unit')
        if wavelength_unit is not None and flux_unit is not None:
            trace = self._get_rescaled_axis_in_trace(trace, to_wavelength_unit=wavelength_unit,
                                                     to_flux_unit=flux_unit)
        else:
            # no axis units yet: adopt this trace's units (the axis-unit values read above are unset here)
            self._set_axis_units(self.app_data, trace['wavelength_unit'], trace['flux_unit'])
        added_traces.append(trace)
    self._add_trace_to_data(self.app_data, added_traces, do_update_client=False)
    self.update_client()
def get_spectrum(self, name):
    """
    Returns a trace of the viewer as a Spectrum object.

    Parameters
    ----------
    name : str
        Name of the trace (KeyError if not present).
    Returns
    -------
    Spectrum
        Built from the trace's wavelength, flux, flux_error, masks, mask_bits, units, catalog,
        spectrum_type, color, linewidth and alpha.
    """
    trace = self.app_data['traces'][name]
    copied_fields = ('wavelength', 'flux', 'flux_error', 'masks', 'mask_bits', 'wavelength_unit', 'flux_unit',
                     'catalog', 'spectrum_type', 'color', 'linewidth', 'alpha')
    field_values = {field: trace[field] for field in copied_fields}
    return Spectrum(name=name, **field_values)
def get_trace_names(self):
    """
    Returns
    -------
    list of str
        Names of all traces, in the order they are stored in the application data.
    """
    return list(self.app_data['traces'])
def add_spectrum_from_file(self, file_path, catalog_name, display_name=None, to_wavelength_unit=None,
                           to_flux_unit=None):
    """
    Loads one or several spectrum files into the viewer and updates the client once at the end.

    Parameters
    ----------
    file_path : str or list of str
        Path(s) of the file(s) to load.
    catalog_name : str
        Catalog the file(s) belong to.
    display_name : str or list of str or None
        Trace name(s). With a list of paths it must be a list of the same length, or None to pass None
        as the name of every file.
    to_wavelength_unit
        Forwarded to '_add_spectrum_from_file' (default None).
    to_flux_unit
        Forwarded to '_add_spectrum_from_file' (default None).
    Returns
    -------
    None
    Raises
    ------
    Exception
        On wrong argument types or mismatched list lengths.
    """
    if isinstance(file_path, str) and (isinstance(display_name, str) or display_name is None):
        file_paths = [file_path]
        display_names = [display_name]
    else:
        if not isinstance(file_path, list) or (display_name is not None and not isinstance(display_name, list)):
            raise Exception("Wrong type for file_path or display_name parameters.")
        file_paths = file_path
        # a list of paths without names: one None name per path (previously left None, crashing on len())
        display_names = [None] * len(file_path) if display_name is None else display_name
    if len(file_paths) != len(display_names):
        raise Exception("file_path and display_name should be lists of the same length")
    for path, name in zip(file_paths, display_names):
        self._add_spectrum_from_file(path, self.app_data, to_wavelength_unit, to_flux_unit, catalog_name,
                                     name, do_update_client=False)
    self.update_client()
def add_spectrum_from_id(self, specid, catalog_name, display_name=None, to_wavelength_unit=None, to_flux_unit=None):
    """
    Loads one or several spectra by id from a catalog and updates the client once at the end.

    Parameters
    ----------
    specid : str or list of str
        Spectrum id(s) to load.
    catalog_name : str
        Catalog to load from.
    display_name : str or list of str or None
        Trace name(s). With a list of ids it must be a list of the same length, or None to pass None as
        the name of every id.
    to_wavelength_unit
        Forwarded to '_load_from_specid' (default None).
    to_flux_unit
        Forwarded to '_load_from_specid' (default None).
    Returns
    -------
    None
    Raises
    ------
    Exception
        On wrong argument types or mismatched list lengths.
    """
    if isinstance(specid, str) and (isinstance(display_name, str) or display_name is None):
        specids = [specid]
        display_names = [display_name]
    else:
        if not isinstance(specid, list) or (display_name is not None and not isinstance(display_name, list)):
            raise Exception("Wrong type for specid or display_name parameters.")
        specids = specid
        # a list of ids without names: one None name per id (previously left None, crashing on len())
        display_names = [None] * len(specid) if display_name is None else display_name
    if len(specids) != len(display_names):
        raise Exception("specid and display_name should be lists of the same length")
    self._load_from_specid(specids, display_names, to_wavelength_unit, to_flux_unit, self.app_data, catalog_name,
                           do_update_client=False)
    self.update_client()
def get_catalog_names(self):
    """
    Returns
    -------
    The catalogs supported for loading spectra, as reported by the project's 'input' module
    ('input.get_supported_catalogs()').
    """
    supported_catalogs = input.get_supported_catalogs()
    return supported_catalogs
def update_trace(self, name, trace):
    """
    Stores a trace under the given name (replacing any existing one) and updates the client.

    The traces dict is read, modified and written back as a whole, like the other methods here that
    modify app_data entries (e.g. 'toggle'). With a multiprocessing manager dict, item lookups return
    copies, so modifying the nested dict in place would be lost.

    Parameters
    ----------
    name : str
        Trace name.
    trace : dict
        Trace data to store.
    Returns
    -------
    None
    """
    traces = self.app_data['traces']
    traces[name] = trace
    self.app_data['traces'] = traces
    self.update_client()
def remove_trace(self, name, also_remove_children=True):
    """
    Removes a trace from the viewer and updates the client.

    Parameters
    ----------
    name : str
        Name of the trace to remove.
    also_remove_children : bool
        Forwarded to '_remove_traces' (default True).
    Returns
    -------
    None
    """
    names_to_remove = [name]
    self._remove_traces(
        names_to_remove,
        self.app_data,
        do_update_client=True,
        also_remove_children=also_remove_children,
    )
def toggle(self, name=None, is_visible=True, all_traces=True):
    """
    Shows or hides traces and updates the client.

    Parameters
    ----------
    name : str or list of str or None
        Trace name(s) to change. If None, all traces are changed when all_traces is True.
    is_visible : bool
        Traces become visible only when this is exactly True; any other value hides them.
    all_traces : bool
        Used only when name is None.
    Returns
    -------
    None
    Raises
    ------
    ValueError
        If name is None and all_traces is not True (nothing to toggle).
    Exception
        If a name is not found; in that case no trace is modified.
    """
    if isinstance(name, str):
        names = [name]
    elif name is None:
        if all_traces is not True:
            raise ValueError("No trace name given and all_traces is not True")
        names = self.get_trace_names()
    else:
        names = name
    traces = self.app_data['traces']
    # validate all names first so a bad name does not leave some traces already toggled
    for _name in names:
        if _name not in traces:
            raise Exception(f"{_name} not found in trace list")
    for _name in names:
        trace = traces[_name]
        trace["is_visible"] = is_visible is True
        traces[_name] = trace
    # write the traces back as a whole (read-modify-write on app_data)
    self.app_data['traces'] = traces
    self.update_client()
def set_axis_units(self, wavelength_unit=WavelengthUnit.ANGSTROM, flux_unit=FluxUnit.F_lambda):
    """
    Converts the viewer's axes to the given units through '_rescale_axis' and updates the client.

    Parameters
    ----------
    wavelength_unit
        Target wavelength unit (default WavelengthUnit.ANGSTROM).
    flux_unit
        Target flux unit (default FluxUnit.F_lambda).
    Returns
    -------
    None
    """
    app_data = self.app_data
    self._rescale_axis(app_data, to_wavelength_unit=wavelength_unit, to_flux_unit=flux_unit)
    self.update_client()