import logging
import os
from datetime import datetime
from glob import glob
from pathlib import Path
try:
import imaspy as imas
except ImportError:
import imas
import yaml
# Module-level logger, named "module.<this module>" so it sits under the "module" logger in the
# logging hierarchy.
logger = logging.getLogger(f"module.{__name__}")
class DBMaster:
    """
    Helpers for locating and enumerating IMAS databases on the local filesystem.

    Databases live under a per-user root (see :meth:`get_user_dir`) laid out as
    ``<root>/<database>/<version>/...``. Pulses stored with the HDF5 backend are discovered through
    ``*master.h5`` files, pulses stored with the MDSplus backend through ``*.datafile`` files.
    """

    ALL_BACKENDS = "mdsplus", "hdf5"

    # ------------------------------------------------------------------ private helpers

    @staticmethod
    def _file_time(path):
        """Return the modification time of ``path`` as a local ``datetime`` truncated to whole seconds."""
        return datetime.fromtimestamp(os.path.getmtime(path)).replace(microsecond=0)

    @staticmethod
    def _add_entry(pulses, key, entry):
        """Append ``entry`` to ``pulses`` (a list) or to ``pulses[key]`` when ``pulses`` is a dict."""
        if isinstance(pulses, dict):
            pulses.setdefault(key, []).append(entry)
        else:
            pulses.append(entry)

    @staticmethod
    def _path_component_as_int(parts, index, default=0):
        """Return ``int(parts[index])`` when that component exists and is all digits, else ``default``."""
        try:
            component = parts[index]
        except IndexError:
            return default
        return int(component) if component.isdigit() else default

    @staticmethod
    def _load_yaml_quietly(text):
        """Parse ``text`` as YAML and return the result, or ``None`` if it is not valid YAML."""
        try:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects; these files come from
            # (possibly shared) database directories — consider yaml.safe_load.
            return yaml.load(text, Loader=yaml.Loader)
        except yaml.YAMLError:
            return None

    @staticmethod
    def _scenario_status_matches(scenario_yaml_dir, pulse, run, status):
        """
        Return True when the scenario summary file of ``pulse``/``run`` declares ``status``.

        The file is looked up as ``<scenario_yaml_dir>/ids_<pulse><run zero-padded to 4>.yaml``. A missing
        file, or a file without status information, counts as the empty status ``""`` (and is logged).
        """
        yaml_file_path = os.path.join(scenario_yaml_dir, f"ids_{pulse}{str(run).zfill(4)}.yaml")
        status_from_yaml = ""
        if os.path.exists(yaml_file_path):
            status_from_yaml = DBMaster.get_pulse_status(yaml_file_path)
            if status_from_yaml == "":
                logger.warning("could not find status info in scenario file %s/%s %s", pulse, run, yaml_file_path)
        else:
            logger.warning("scenario summary file does not exist for %s/%s %s", pulse, run, yaml_file_path)
        return status == status_from_yaml

    # ------------------------------------------------------------------ directories

    @staticmethod
    def get_user_dir(user: str = None):
        """
        Return the root directory holding the IMAS databases of ``user``.

        Args:
            user (str): Owner of the databases. When empty or ``None`` the current login name is used
                (``os.getlogin()``, falling back to ``getpass.getuser()`` when there is no controlling
                terminal). The special value ``"public"`` selects the shared databases.
        Returns:
            ``"$IMAS_HOME/shared/imasdb/"`` when ``user`` is ``"public"``, otherwise
            ``"~<user>/public/imasdb/"`` with the home directory expanded. The path ends with ``/``.
        Raises:
            FileNotFoundError: if ``user`` is ``"public"`` and ``IMAS_HOME`` is not set.
        """
        if not user:
            try:
                user = os.getlogin()
            except OSError:
                # os.getlogin() needs a controlling terminal (fails under cron, CI, containers, ...).
                import getpass

                user = getpass.getuser()
        if user != "public":
            return f'{os.path.expanduser(f"~{user}")}/public/imasdb/'
        imas_home_dir = os.environ.get("IMAS_HOME")
        if imas_home_dir is None:
            raise FileNotFoundError("File path in the environment variable IMAS_HOME is not defined.")
        return f"{imas_home_dir}/shared/imasdb/"

    @staticmethod
    def get_database_dir(database: str, user: str = None):
        """
        Return the directory of ``database`` inside the database root of ``user``.

        Args:
            database (str): Name of the database (a sub-directory of the user's root). May be ``None``.
            user (str): Owner of the database, resolved with :meth:`get_user_dir`.
        Returns:
            ``<user root><database>`` when it exists, or ``None`` when ``database`` is ``None``.
        Raises:
            FileNotFoundError: if the database directory does not exist (or, via :meth:`get_user_dir`,
                if ``IMAS_HOME`` is unset for the public user).
        """
        user_dir = DBMaster.get_user_dir(user)
        if database is None:
            return None
        user_database_dir = user_dir + database
        if not os.path.exists(user_database_dir):
            raise FileNotFoundError(
                "The path provided does not exist or has no such database file or directory. "
                "Please check spelling."
            )
        return user_database_dir

    @staticmethod
    def get_version_dir(version: str, database: str, user: str = None):
        """
        Return the directory of ``version`` inside ``database``.

        Args:
            version (str): Version sub-directory name (e.g. ``"3"``). May be ``None``.
            database (str): Name of the database.
            user (str): Owner of the database, resolved with :meth:`get_user_dir`.
        Returns:
            ``<database dir>/<version>`` when it exists; ``None`` when ``version`` or ``database`` is
            ``None`` or the version directory does not exist.
        Raises:
            FileNotFoundError: if the database directory itself does not exist.
        """
        database_dir = DBMaster.get_database_dir(database, user)
        if version is None or database_dir is None:
            return None
        version_dir = f"{database_dir}/{version}"
        return version_dir if os.path.exists(version_dir) else None

    @staticmethod
    def get_databases(user: str = None) -> list:
        """
        Return the sorted names of the database directories of ``user``.

        Args:
            user (str): Owner of the databases, resolved with :meth:`get_user_dir`.
        Returns:
            list: Sorted names of the sub-directories of the user's database root.
        """
        user_dir = DBMaster.get_user_dir(user)
        return sorted(name for name in os.listdir(user_dir) if os.path.isdir(os.path.join(user_dir, name)))

    @staticmethod
    def get_versions(database: str, user: str = None) -> list:
        """
        Return the sorted version directory names of ``database``.

        Args:
            database (str): Name of the database. ``None`` yields an empty list.
            user (str): Owner of the database, resolved with :meth:`get_user_dir`.
        Returns:
            list: Sorted names of the sub-directories of the database directory.
        Raises:
            FileNotFoundError: if the database directory does not exist.
        """
        database_dir = DBMaster.get_database_dir(database, user)
        if database_dir is None:
            # Nothing to list; os.listdir(None) would silently fall back to the current directory.
            return []
        return sorted(
            name for name in os.listdir(database_dir) if os.path.isdir(os.path.join(database_dir, name))
        )

    @staticmethod
    def get_databases_with_versions(user: str = None) -> list:
        """
        Return ``(database, [versions])`` pairs for every database of ``user``.

        Args:
            user (str): Owner of the databases, resolved with :meth:`get_user_dir`.
        Returns:
            list: Tuples ``(database_name, sorted_versions)`` sorted by database name.
        """
        return [(name, DBMaster.get_versions(name, user)) for name in DBMaster.get_databases(user)]

    @staticmethod
    def get_versions_with_databases(user: str = None) -> list:
        """
        Return ``(version, [databases])`` pairs: for each version, the databases that contain it.

        Args:
            user (str): Owner of the databases, resolved with :meth:`get_user_dir`.
        Returns:
            list: Tuples ``(version, database_names)`` sorted by version name (string order); the
            database names keep database-name order.
        """
        databases_by_version = {}
        for database, versions in DBMaster.get_databases_with_versions(user=user):
            for version in versions:
                databases_by_version.setdefault(version, []).append(database)
        return [(version, databases_by_version[version]) for version in sorted(databases_by_version)]

    # ------------------------------------------------------------------ pulse discovery

    @staticmethod
    def get_hdf5_pulses(
        user: str = None, database: str = None, version: str = None, status=None, as_dictionary=False
    ) -> list:
        """
        List the pulses stored with the HDF5 backend in ``<user>/<database>/<version>``.

        Every ``<...>/<pulse>/<run>/*master.h5`` file below the version directory is one pulse/run.
        Entries whose pulse or run folder is not numeric are skipped with a warning.

        Args:
            user (str): Owner of the database, resolved with :meth:`get_user_dir`.
            database (str): Name of the database.
            version (str): Version directory name.
            status (str): When not ``None``, keep only pulses whose scenario file
                ``<version>/0/ids_<pulse><run:04>.yaml`` declares this status.
            as_dictionary (bool): Return a dict ``{pulse: [entries]}`` instead of a flat list.
                Defaults to False.
        Returns:
            A list (or dict of lists) of tuples ``(pulse, run, HDF5_BACKEND, database, user, version,
            master_file_path, modification_time)``. Empty when the version directory does not exist.
        """
        pulses = {} if as_dictionary else []
        version_dir = DBMaster.get_version_dir(version, database, user)
        if version_dir is None:
            return pulses
        scenario_yaml_dir = os.path.join(version_dir, "0")
        for master_path in glob(f"{version_dir}/**/*master.h5", recursive=True):
            parts = master_path.split("/")
            run, pulse = parts[-2], parts[-3]
            if not run.isdigit():
                logger.warning("run number is not an integer %s %s", run, master_path)
                continue
            if not pulse.isdigit():
                logger.warning("pulse number is not an integer %s/%s %s", pulse, run, master_path)
                continue
            pulse, run = int(pulse), int(run)
            if status is not None and not DBMaster._scenario_status_matches(scenario_yaml_dir, pulse, run, status):
                continue
            entry = (
                pulse,
                run,
                imas.ids_defs.HDF5_BACKEND,
                database,
                user,
                version,
                master_path,
                DBMaster._file_time(master_path),
            )
            DBMaster._add_entry(pulses, pulse, entry)
        return pulses

    @staticmethod
    def get_hdf5_pulses_from_folder(folder: str = None, as_dictionary=False) -> list:
        """
        List the HDF5 ``*master.h5`` files found anywhere below ``folder``.

        The pulse and run are read from the two parent folders (``.../<pulse>/<run>/...master.h5``);
        a component that is missing or not numeric becomes 0.

        Args:
            folder (str): Directory to scan recursively.
            as_dictionary (bool): Return a dict ``{master_file_path: [entry]}`` instead of a list.
        Returns:
            A list (or dict of lists) of tuples ``(pulse, run, HDF5_BACKEND, master_file_path,
            modification_time)``.
        """
        pulses = {} if as_dictionary else []
        for master_path in glob(f"{folder}/**/*master.h5", recursive=True):
            parts = master_path.split("/")
            run = DBMaster._path_component_as_int(parts, -2)
            pulse = DBMaster._path_component_as_int(parts, -3)
            entry = (pulse, run, imas.ids_defs.HDF5_BACKEND, master_path, DBMaster._file_time(master_path))
            DBMaster._add_entry(pulses, master_path, entry)
        return pulses

    @staticmethod
    def get_mds_plus_pulses(
        user: str = None,
        database: str = None,
        version: str = None,
        status: str = None,
        as_dictionary=False,
    ) -> list:
        """
        List the pulses stored with the MDSplus backend in ``<user>/<database>/<version>``.

        Two layouts are recognised below the version directory:

        * AL4: ``<run // 10000>/ids_<pulse><run % 10000 as 4 digits>.datafile``
        * AL5: ``<pulse>/<run>/ids_001.datafile`` (symbolic links are ignored)

        Files that do not fit either layout are skipped with a warning.

        Args:
            user (str): Owner of the database, resolved with :meth:`get_user_dir`.
            database (str): Name of the database.
            version (str): Version directory name.
            status (str): When not ``None``, keep only pulses whose scenario file
                ``<version>/0/ids_<pulse><run:04>.yaml`` declares this status.
            as_dictionary (bool): Return a dict ``{pulse: [entries]}`` (one entry per run) instead of a
                flat list. Defaults to False.
        Returns:
            A list (or dict of lists) of tuples ``(pulse, run, MDSPLUS_BACKEND, database, user, version,
            datafile_path, modification_time)``. Empty when the version directory does not exist.
        """
        pulses = {} if as_dictionary else []
        mdsplus_dir = DBMaster.get_version_dir(version, database, user)
        if mdsplus_dir is None:
            return pulses
        scenario_yaml_dir = os.path.join(mdsplus_dir, "0")
        for data_file_path in glob(f"{mdsplus_dir}/**/*.datafile", recursive=True):
            root, datafile = os.path.split(data_file_path)
            sub_dirs = root[len(mdsplus_dir) + 1 :].split("/")
            if len(sub_dirs) == 1:  # AL4 layout
                number = datafile[datafile.find("_") + 1 : datafile.rfind(".")]
                if not number.isdigit():
                    logger.warning("pulse/run number is not an integer %s", data_file_path)
                    continue
                number = int(number)
                pulse = number // 10000
                if not sub_dirs[0].isdigit():
                    logger.warning("run number is not an integer %s %s", sub_dirs[0], data_file_path)
                    continue
                run = int(sub_dirs[0]) * 10000 + number % 10000
            else:  # AL5 layout
                if datafile != "ids_001.datafile":
                    logger.warning("ids_001.datafile does not exist %s", data_file_path)
                    continue
                if os.path.islink(data_file_path):
                    continue
                parts = root.split("/")
                run, pulse = parts[-1], parts[-2]
                if not run.isdigit():
                    logger.warning("run number is not an integer %s %s", run, data_file_path)
                    continue
                if not pulse.isdigit():
                    logger.warning("pulse number is not an integer %s/%s %s", pulse, run, data_file_path)
                    continue
                pulse, run = int(pulse), int(run)
            if status is not None and not DBMaster._scenario_status_matches(scenario_yaml_dir, pulse, run, status):
                continue
            entry = (
                pulse,
                run,
                imas.ids_defs.MDSPLUS_BACKEND,
                database,
                user,
                version,
                data_file_path,
                DBMaster._file_time(data_file_path),
            )
            if as_dictionary:
                runs = pulses.setdefault(pulse, [])
                # Keep only the first entry found for a given pulse/run.
                if all(existing[1] != run for existing in runs):
                    runs.append(entry)
            else:
                pulses.append(entry)
        return pulses

    @staticmethod
    def get_mds_plus_pulses_from_folder(
        folder: str = None,
        as_dictionary=False,
    ) -> list:
        """
        List the MDSplus ``*.datafile`` files found anywhere below ``folder``.

        AL4 files (one level below ``folder``: ``<run // 10000>/ids_<pulse><run % 10000>.datafile``)
        are decoded from their name; other files are only accepted when named ``ids_001.datafile``
        and take pulse/run from their ``<pulse>/<run>`` parent folders. Folder components that are
        missing or not numeric become 0.

        Args:
            folder (str): Directory to scan recursively.
            as_dictionary (bool): Return a dict ``{datafile_path: [entry]}`` instead of a list.
        Returns:
            A list (or dict of lists) of tuples ``(pulse, run, MDSPLUS_BACKEND, datafile_path,
            modification_time)``.
        """
        pulses = {} if as_dictionary else []
        prefix_length = len(f"{folder}") + 1
        for data_file_path in glob(f"{folder}/**/*.datafile", recursive=True):
            root, datafile = os.path.split(data_file_path)
            sub_dirs = root[prefix_length:].split("/")
            if len(sub_dirs) == 1 and sub_dirs[0] != "":  # AL4 layout
                number = datafile[datafile.find("_") + 1 : datafile.rfind(".")]
                if not number.isdigit():
                    logger.warning("pulse/run number is not an integer %s", data_file_path)
                    continue
                number = int(number)
                pulse = number // 10000
                run = int(sub_dirs[0]) * 10000 + number % 10000 if sub_dirs[0].isdigit() else 0
            else:  # AL5 layout
                if datafile != "ids_001.datafile":
                    continue
                parts = root.split("/")
                # Bug fix: numeric folders used to become 0 while non-numeric ones crashed in int().
                run = DBMaster._path_component_as_int(parts, -1)
                pulse = DBMaster._path_component_as_int(parts, -2)
            try:
                file_time = DBMaster._file_time(data_file_path)
            except FileNotFoundError:
                # e.g. a dangling symbolic link
                logger.warning("invalid file %s", data_file_path)
                continue
            entry = (pulse, run, imas.ids_defs.MDSPLUS_BACKEND, data_file_path, file_time)
            DBMaster._add_entry(pulses, data_file_path, entry)
        return pulses

    @staticmethod
    def get_pulse_status(yaml_file_path) -> str:
        """
        Return the ``status`` entry of a scenario YAML file.

        Only a small context is parsed around each line starting with ``status:``. The context is
        that line plus one line before and one after, so the rest of the file need not be valid
        YAML. If that context does not yield a mapping with a top-level ``status`` key (for
        instance because the previous line is its parent key), the ``status:`` line is parsed on
        its own. When several lines match, the last usable one wins.

        Args:
            yaml_file_path (str or Path): Path to the scenario YAML file.
        Returns:
            The value of the ``status`` key, or ``""`` when none could be found.
        Raises:
            OSError: if the file cannot be opened.
        """
        status = ""
        with open(Path(yaml_file_path), "r") as file_handle:
            lines = file_handle.readlines()
        for i, line in enumerate(lines):
            if not line.strip().startswith("status:"):
                continue
            context = "".join(lines[max(0, i - 1) : i + 2])
            for candidate in (context, line.strip()):
                metadata = DBMaster._load_yaml_quietly(candidate)
                if isinstance(metadata, dict) and "status" in metadata:
                    status = metadata["status"]
                    break
        return status

    @staticmethod
    def get_database_files(user=None, database=None, version=None, backends=None):
        """
        Collect the pulses of one or all databases/versions of ``user``, grouped by backend.

        Args:
            user: Owner of the databases, resolved with :meth:`get_user_dir`.
            database: Restrict to this database; all databases of ``user`` when falsy.
            version: Restrict to this version; all versions of each database when falsy.
            backends: Backends to scan (``"hdf5"`` and/or ``"mdsplus"``); defaults to
                ``DBMaster.ALL_BACKENDS``.
        Returns:
            ``[(database, [(version, [(backend, pulses_by_pulse), ...]), ...]), ...]`` where
            ``pulses_by_pulse`` is the dictionary returned by :meth:`get_hdf5_pulses` /
            :meth:`get_mds_plus_pulses` with ``as_dictionary=True``. Empty levels are omitted.
        Raises:
            NotImplementedError: for an unknown backend name.
        """
        readers = {"hdf5": DBMaster.get_hdf5_pulses, "mdsplus": DBMaster.get_mds_plus_pulses}
        backends = backends or DBMaster.ALL_BACKENDS
        result = []
        for database_name in [database] if database else DBMaster.get_databases(user):
            database_files = []
            for version_name in [version] if version else DBMaster.get_versions(database_name, user):
                pulses = []
                for backend in backends:
                    reader = readers.get(backend)
                    if reader is None:
                        raise NotImplementedError(f"Unsupported backend: {backend}")
                    dbs = reader(user, database_name, version_name, as_dictionary=True)
                    if dbs:
                        pulses.append((backend, dbs))
                if pulses:
                    database_files.append((version_name, pulses))
            if database_files:
                result.append((database_name, database_files))
        return result

    @staticmethod
    def get_database_files_from_folder(folder=None, backends=None):
        """
        Retrieve database files from a folder based on specified backends.

        Parameters
        ----------
        folder : str, optional
            The folder path from which to retrieve database files.
        backends : list, optional
            Backends to scan ('hdf5' and/or 'mdsplus'); defaults to DBMaster.ALL_BACKENDS.

        Returns
        -------
        list
            Tuples ``(backend, files)`` where ``files`` is the dictionary returned by
            :meth:`get_hdf5_pulses_from_folder` / :meth:`get_mds_plus_pulses_from_folder`;
            backends without files are omitted.

        Raises
        ------
        NotImplementedError
            For an unknown backend name.
        """
        readers = {
            "hdf5": DBMaster.get_hdf5_pulses_from_folder,
            "mdsplus": DBMaster.get_mds_plus_pulses_from_folder,
        }
        pulses = []
        for backend in backends or DBMaster.ALL_BACKENDS:
            reader = readers.get(backend)
            if reader is None:
                raise NotImplementedError(f"Unsupported backend: {backend}")
            dbs = reader(folder, as_dictionary=True)
            if dbs:
                pulses.append((backend, dbs))
        return pulses

    # ------------------------------------------------------------------ physical files

    @staticmethod
    def get_hdf5_physical_file(user, database, version, pulse, run):
        """
        Return the HDF5 file path ``<user root>/<database>/<version>/hdf5/ids_<pulse>_<run>.hd5``.

        NOTE(review): this layout differs from the ``<pulse>/<run>/master.h5`` files scanned by
        :meth:`get_hdf5_pulses` — confirm which layout callers expect.
        """
        hdf5_dir = os.path.join(DBMaster.get_user_dir(user), database, version, "hdf5")
        return os.path.join(hdf5_dir, f"ids_{pulse}_{run}.hd5")

    @staticmethod
    def get_mdsplus_physical_files(user, database, version, pulse, run):
        """
        Return the MDSplus files of a pulse/run in the AL4 layout.

        The base name is ``<user root>/<database>/<version>/<run // 10000>/ids_<pulse><run % 10000:04>``;
        for pulse 0 it is ``.../ids_<run % 10000:03>`` instead.

        Returns:
            tuple: ``(base + ".characteristics", base + ".datafile", base + ".tree")``.
        """
        mdsplus_dir = os.path.join(DBMaster.get_user_dir(user), database, version)
        run_digits = str(run % 10000)
        if pulse == 0:
            base_name = f"ids_{run_digits.zfill(3)}"
        else:
            base_name = f"ids_{pulse}{run_digits.zfill(4)}"
        mdsplus_file_name = os.path.join(mdsplus_dir, str(int(run // 10000)), base_name)
        return tuple(f"{mdsplus_file_name}{extension}" for extension in (".characteristics", ".datafile", ".tree"))

    @staticmethod
    def get_physical_files(user, database, version, pulse, run, backend):
        """
        Return the physical file(s) storing a pulse/run for ``backend``.

        Returns:
            The tuple from :meth:`get_mdsplus_physical_files` for ``"mdsplus"``, or the path from
            :meth:`get_hdf5_physical_file` for ``"hdf5"``.
        Raises:
            NotImplementedError: for any other backend.
        """
        if backend == "mdsplus":
            return DBMaster.get_mdsplus_physical_files(user, database, version, pulse, run)
        if backend == "hdf5":
            return DBMaster.get_hdf5_physical_file(user, database, version, pulse, run)
        raise NotImplementedError(f"Unsupported backend: {backend}")

    # ------------------------------------------------------------------ IMAS access

    @classmethod
    def get_dd_version(cls):
        """Return the data dictionary version of a default ``imas.ids_factory.IDSFactory``."""
        return imas.ids_factory.IDSFactory().dd_version

    @classmethod
    def create_connection(cls, imasargs, target_dd_version=None):
        """
        Open ``imas.DBEntry(imasargs.uri, imasargs.mode, dd_version=target_dd_version)``.

        When ``imasargs`` has no ``mode`` attribute it is set to ``"w"`` on ``imasargs`` first.

        Returns:
            The DBEntry, or ``None`` when ``imasargs.uri`` is empty or ``None``.
        """
        if "mode" not in vars(imasargs):
            imasargs.mode = "w"
        if not imasargs.uri:
            return None
        return imas.DBEntry(imasargs.uri, imasargs.mode, dd_version=target_dd_version)

    @classmethod
    def get_connection(cls, imasargs):
        """
        Open ``imas.DBEntry`` for ``imasargs.uri``.

        The entry is opened in ``imasargs.mode`` when that attribute exists (errors propagate),
        otherwise in read mode ``"r"``, where failures are logged and ``None`` is returned.

        Returns:
            The DBEntry, or ``None`` when ``imasargs.uri`` is empty/``None`` or read-mode opening failed.
        """
        if not imasargs.uri:
            return None
        if "mode" in vars(imasargs):
            return imas.DBEntry(imasargs.uri, imasargs.mode)
        try:
            return imas.DBEntry(imasargs.uri, "r")
        except Exception as error:  # best effort: report and let the caller handle None
            logger.error("could not open %s: %s", imasargs.uri, error)
            return None

    # ------------------------------------------------------------------ pulse/run listings

    @staticmethod
    def pulse_list2_dict(pulselist):
        """Utility function that returns a dict from a list of pairs (pulse,run)

        Parameters
        ----------
        pulselist: list of tuples
            List of tuples (pulse,run)

        Returns
        -------
        dict key=pulse:value=[runs]
        """
        pulsedict = {}
        for pulse, run in pulselist:
            pulsedict.setdefault(pulse, []).append(run)
        return pulsedict

    @staticmethod
    def mds_list_pulse_run(locpath, with_status=None, as_dict=False):
        """List Pulse and Run numbers from a given database, in MDSPLUS

        Parameters
        ----------
        locpath: str or Path
            Path in which the database files are stored
        with_status: str
            If set, list only entries whose sibling ``.yaml`` file declares this status
            (e.g. 'obsolete', 'active'); a missing yaml file counts as status ``""``.
        as_dict: bool
            Accepted for API compatibility; currently ignored.

        Returns
        -------
        sorted list of unique tuples (pulse, run)

        Raises
        ------
        FileNotFoundError
            If ``locpath`` does not exist.
        """
        locpath = Path(locpath).expanduser()
        if not locpath.exists():
            raise FileNotFoundError(
                "The path provided does not exist or has no such database file or directory. Please check spelling."
            )
        pulses = set()
        # glob.glob instead of Path.glob, which does not follow symlinked sub-folders
        # (https://bugs.python.org/issue33428).
        for entry in glob(str(locpath) + "/**/*.datafile", recursive=True):
            if with_status is not None:
                yaml_path = Path(entry).with_suffix(".yaml")
                entry_status = DBMaster.get_pulse_status(yaml_path) if yaml_path.exists() else ""
                if entry_status != with_status:
                    continue
            parts = entry.split("/")
            name_fields = parts[-1].split("_")
            if len(name_fields) < 2:
                continue
            number = name_fields[1].split(".")[0]
            if not number.isdigit():
                continue  # e.g. ids_model.datafile
            if len(number) <= 4:
                # Short names (e.g. ids_001): take pulse and run from the <pulse>/<run> parent folders.
                if len(parts) < 3 or not (parts[-3].isdigit() and parts[-2].isdigit()):
                    continue
                pulses.add((int(parts[-3]), int(parts[-2])))
            else:
                # ids_<pulse><run as 4 digits>.datafile
                # NOTE(review): ignores the <run // 10000> folder, unlike get_mds_plus_pulses.
                pulses.add((int(number[:-4]), int(number[-4:])))
        return sorted(pulses)

    @staticmethod
    def hdf5_list_pulse_run(locpath):
        """List Pulse and Run numbers from a given database, in HDF5

        Parameter
        ---------
        locpath: str or Path
            Path in which the database files are stored

        Returns
        -------
        list of tuple (pulse, run); a non-numeric folder name is kept as a string.

        Raises
        ------
        FileNotFoundError
            If ``locpath`` does not exist.
        """
        locpath = Path(locpath).expanduser()
        if not locpath.exists():
            raise FileNotFoundError(
                "The path provided does not exist or has no such database file or directory. Please check spelling."
            )
        pulses = []
        for entry in glob(str(locpath) + "/**/*master.h5", recursive=True):
            parts = entry.split("/")
            if len(parts) < 3:
                continue
            pulse, run = parts[-3], parts[-2]
            pulses.append((int(pulse) if pulse.isdigit() else pulse, int(run) if run.isdigit() else run))
        return pulses

    @staticmethod
    def get_db_Path(user, database, version):
        """Return a pathlib Path to a database version directory.

        Parameters
        ---------
        user: str
            ``"public"`` for the shared databases under ``$IMAS_HOME``, otherwise a user name
            (``None``/empty means the current user, see :meth:`get_user_dir`).
        database: str
            Name of database where the data is harbored
        version: str
            String of number of data version

        Returns
        -------
        pathlib.Path
            ``<user root>/<database>/<version>`` (existence is not checked).

        Raises
        ------
        FileNotFoundError
            If ``user`` is ``"public"`` and ``IMAS_HOME`` is not set.
        """
        return Path(DBMaster.get_user_dir(user)) / database / version
def read_scenario(
    scenario_file_path: str,
    in_ids_list: list = None,
    out_ids_list: list = None,
    test_mode: bool = False,
    **test_args,
):
    """
    Read IDSs from the input and output MDSplus databases described by a scenario YAML file.

    The scenario file must provide ``input_database``, ``output_database``, ``shot``, ``run_in``,
    ``run_out``, ``input_user_or_path`` and ``output_user_or_path``. An ``output_user_or_path`` of
    ``"default"`` is replaced by the ``USER`` environment variable.

    Args:
        scenario_file_path (str): Path of the scenario YAML file.
        in_ids_list (list): Names of the IDSs read from the input database. Defaults to none.
        out_ids_list (list): Names of the IDSs read from the output database. Defaults to none.
        test_mode (bool): When True, IDSs are read with ``get_slice`` instead of ``get``. The values of
            ``test_args`` are passed as positional arguments in keyword order (e.g. time and
            interpolation method). Defaults to False.
        **test_args: Extra ``get_slice`` arguments used in test mode.
    Returns:
        tuple: ``(in_ids_dict, out_ids_dict, inputargs)``. The dicts map IDS names to the IDSs read.
        ``inputargs`` is an ``argparse.Namespace`` describing the input database (backend, pulse,
        run, user, database, ``version=3``, ``uri=None``).
    """
    import argparse

    slice_args = list(test_args.values())
    in_ids_list = [] if in_ids_list is None else in_ids_list
    out_ids_list = [] if out_ids_list is None else out_ids_list

    with open(scenario_file_path, "r") as scenario_file:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects; only use trusted scenario
        # files, or switch to yaml.safe_load.
        config = yaml.load(scenario_file, Loader=yaml.Loader)

    def _read_all(connection, ids_names):
        """Open ``connection``, read every IDS in ``ids_names`` and close it again, even on error."""
        connection.open()
        try:
            return {
                # get_slice takes its arguments positionally; unpack them instead of passing one list.
                ids_name: connection.get_slice(ids_name, *slice_args) if test_mode else connection.get(ids_name)
                for ids_name in ids_names
            }
        finally:
            connection.close()

    # Read the requested IDSs from the input datafile
    in_ids_dict = _read_all(
        imas.DBEntry(
            imas.ids_defs.MDSPLUS_BACKEND,
            config["input_database"],
            config["shot"],
            config["run_in"],
            config["input_user_or_path"],
        ),
        in_ids_list,
    )

    # Read the requested IDSs from the output datafile
    output_user = os.getenv("USER") if config["output_user_or_path"] == "default" else config["output_user_or_path"]
    out_ids_dict = _read_all(
        imas.DBEntry(
            imas.ids_defs.MDSPLUS_BACKEND,
            config["output_database"],
            config["shot"],
            config["run_out"],
            output_user,
        ),
        out_ids_list,
    )

    inputargs = argparse.Namespace()
    inputargs.backend = imas.ids_defs.MDSPLUS_BACKEND
    inputargs.pulse = config["shot"]
    inputargs.run = config["run_in"]
    inputargs.user = config["input_user_or_path"]
    inputargs.database = config["input_database"]
    inputargs.version = 3
    inputargs.uri = None
    return in_ids_dict, out_ids_dict, inputargs
def read_scenario_with_args(
    imasargs,
    in_ids_list: list = None,
    out_ids_list: list = None,
    test_mode: bool = False,
    **test_args,
):
    """
    Read input and output IDSs from the single data entry described by ``imasargs``.

    Args:
        imasargs: Namespace-like object (e.g. ``argparse.Namespace``) with a ``uri`` attribute and an
            optional ``mode`` attribute. It is opened with ``DBMaster.get_connection``.
        in_ids_list (list): Names of the IDSs returned in the first dict. Defaults to none.
        out_ids_list (list): Names of the IDSs returned in the second dict. Defaults to none.
        test_mode (bool): When True, IDSs are read with ``get_slice`` instead of ``get``. The values of
            ``test_args`` are passed as positional arguments in keyword order (e.g. time and
            interpolation method). Defaults to False.
        **test_args: Extra ``get_slice`` arguments used in test mode.
    Returns:
        ``(in_ids_dict, out_ids_dict)`` mapping IDS names to the IDSs read, or ``None`` when no
        connection could be opened.
    """
    slice_args = list(test_args.values())
    in_ids_list = [] if in_ids_list is None else in_ids_list
    out_ids_list = [] if out_ids_list is None else out_ids_list

    connection = DBMaster.get_connection(imasargs)
    if connection is None:
        return None

    def _read(ids_name):
        """Read one IDS: a time slice in test mode, the whole IDS otherwise."""
        if test_mode:
            # get_slice takes its arguments positionally; unpack them instead of passing one list.
            return connection.get_slice(ids_name, *slice_args)
        return connection.get(ids_name)

    try:
        in_ids_dict = {ids_name: _read(ids_name) for ids_name in in_ids_list}
        out_ids_dict = {ids_name: _read(ids_name) for ids_name in out_ids_list}
    finally:
        # Close the entry even when reading an IDS fails.
        connection.close()
    return in_ids_dict, out_ids_dict