diff --git a/src/mavedb/lib/script_environment.py b/src/mavedb/lib/script_environment.py deleted file mode 100644 index d81e909d0..000000000 --- a/src/mavedb/lib/script_environment.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -Environment setup for scripts. -""" - -from sqlalchemy.orm import Session, configure_mappers - -from mavedb import deps -from mavedb.models import * # noqa: F403 - - -def init_script_environment() -> Session: - """ - Set up the environment for a script that may be run from the command line and does not necessarily depend on the - FastAPI framework. - - Features: - - Configures logging for the script. - - Loads the SQLAlchemy data model. - - Returns an SQLAlchemy database session. - """ - # Scan all our model classes and create backref attributes. Otherwise, these attributes only get added to classes once - # an instance of the related class has been created. - configure_mappers() - - return next(deps.get_db()) diff --git a/src/mavedb/models/mapped_variant.py b/src/mavedb/models/mapped_variant.py index 5a418b22e..57cefd030 100644 --- a/src/mavedb/models/mapped_variant.py +++ b/src/mavedb/models/mapped_variant.py @@ -14,8 +14,8 @@ class MappedVariant(Base): id = Column(Integer, primary_key=True) - pre_mapped = Column(JSONB, nullable=True) - post_mapped = Column(JSONB, nullable=True) + pre_mapped = Column(JSONB(none_as_null=True), nullable=True) + post_mapped = Column(JSONB(none_as_null=True), nullable=True) vrs_version = Column(String, nullable=True) error_message = Column(String, nullable=True) modification_date = Column(Date, nullable=False, default=date.today, onupdate=date.today) diff --git a/src/mavedb/scripts/environment.py b/src/mavedb/scripts/environment.py new file mode 100644 index 000000000..f773f55ff --- /dev/null +++ b/src/mavedb/scripts/environment.py @@ -0,0 +1,160 @@ +""" +Environment setup for scripts. 
+""" + +import enum +import logging +import click +from functools import wraps + + +from sqlalchemy.orm import configure_mappers + +from mavedb import deps +from mavedb.models import * # noqa: F403 + + +logger = logging.getLogger(__name__) + + +@enum.unique +class DatabaseSessionAction(enum.Enum): + """ + Enum representing the database session transaction action selected for a + command decorated by :py:func:`.with_database_session`. + + You will not need to use this class unless you provide ``pass_action = + True`` to :py:func:`.with_database_session`. + """ + + DRY_RUN = "rollback" + PROMPT = "prompt" + COMMIT = "commit" + + +@click.group() +def script_environment(): + """ + Set up the environment for a script that may be run from the command line and does not necessarily depend on the + FastAPI framework. + + Features: + - Configures logging for the script. + - Loads the SQLAlchemy data model. + """ + + logging.basicConfig() + + # Un-comment this line to log all database queries: + logging.getLogger("__main__").setLevel(logging.INFO) + # logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + + # Scan all our model classes and create backref attributes. Otherwise, these attributes only get added to classes once + # an instance of the related class has been created. + configure_mappers() + + +def with_database_session(command=None, *, pass_action: bool = False): + """ + Decorator to provide database session and error handling for a *command*. + + The *command* callable must be a :py:class:`click.Command` instance. + + The decorated *command* is called with a ``db`` keyword argument to provide + a :class:`~id3c.db.session.DatabaseSession` object. The call happens + within an exception handler that commits or rollsback the database + transaction, possibly interactively. Three new options are added to the + *command* (``--dry-run``, ``--prompt``, and ``--commit``) to control this + behaviour. + + >>> @click.command + ... @with_database_session + ... 
def cmd(db: DatabaseSession): + ... pass + + If the optional, keyword-only argument *pass_action* is ``True``, then the + :py:class:`.DatabaseSessionAction` selected by the CLI options above is + passed as an additional ``action`` argument to the decorated *command*. + + >>> @click.command + ... @with_database_session(pass_action = True) + ... def cmd(db: DatabaseSession, action: DatabaseSessionAction): + ... pass + + One example where this is useful is when the *command* accesses + non-database resources and wants to extend dry run mode to them as well. + """ + + def decorator(command): + @click.option( + "--dry-run", + "action", + help="Only go through the motions of changing the database (default)", + flag_value=DatabaseSessionAction("rollback"), + type=DatabaseSessionAction, + default=True, + ) + @click.option( + "--prompt", + "action", + help="Ask if changes to the database should be saved", + flag_value=DatabaseSessionAction("prompt"), + type=DatabaseSessionAction, + ) + @click.option( + "--commit", + "action", + help="Save changes to the database", + flag_value=DatabaseSessionAction("commit"), + type=DatabaseSessionAction, + ) + @wraps(command) + def decorated(*args, action, **kwargs): + db = next(deps.get_db()) + + kwargs["db"] = db + + if pass_action: + kwargs["action"] = action + + processed_without_error = None + + try: + command(*args, **kwargs) + + except Exception as error: + processed_without_error = False + + logger.error(f"Aborting with error: {error}") + raise error from None + + else: + processed_without_error = True + + finally: + if action is DatabaseSessionAction.PROMPT: + ask_to_commit = ( + "Commit all changes?" + if processed_without_error + else "Commit successfully processed records up to this point?" 
+ ) + + commit = click.confirm(ask_to_commit) + else: + commit = action is DatabaseSessionAction.COMMIT + + if commit: + logger.info( + "Committing all changes" + if processed_without_error + else "Committing successfully processed records up to this point" + ) + db.commit() + + else: + logger.info("Rolling back all changes; the database will not be modified") + db.rollback() + + return decorated + + return decorator(command) if command else decorator diff --git a/src/mavedb/scripts/export_public_data.py b/src/mavedb/scripts/export_public_data.py index 705e79815..4a52ee808 100644 --- a/src/mavedb/scripts/export_public_data.py +++ b/src/mavedb/scripts/export_public_data.py @@ -34,17 +34,16 @@ from fastapi.encoders import jsonable_encoder from sqlalchemy import select -from sqlalchemy.orm import lazyload +from sqlalchemy.orm import lazyload, Session from mavedb.lib.score_sets import get_score_set_counts_as_csv, get_score_set_scores_as_csv -from mavedb.lib.script_environment import init_script_environment from mavedb.models.experiment import Experiment from mavedb.models.experiment_set import ExperimentSet from mavedb.models.license import License from mavedb.models.score_set import ScoreSet from mavedb.view_models.experiment_set import ExperimentSetPublicDump -db = init_script_environment() +from mavedb.scripts.environment import script_environment, with_database_session logger = logging.getLogger(__name__) @@ -89,68 +88,73 @@ def flatmap(f: Callable[[S], Iterable[T]], items: Iterable[S]) -> Iterable[T]: return chain.from_iterable(map(f, items)) -logger.info("Fetching data sets") - -experiment_sets_query = db.scalars( - select(ExperimentSet) - .where(ExperimentSet.published_date.is_not(None)) - .options( - lazyload(ExperimentSet.experiments.and_(Experiment.published_date.is_not(None))).options( - lazyload( - Experiment.score_sets.and_( - ScoreSet.published_date.is_not(None), ScoreSet.license.has(License.short_name == "CC0") +@script_environment.command() 
+@with_database_session +def export_public_data(db: Session): + experiment_sets_query = db.scalars( + select(ExperimentSet) + .where(ExperimentSet.published_date.is_not(None)) + .options( + lazyload(ExperimentSet.experiments.and_(Experiment.published_date.is_not(None))).options( + lazyload( + Experiment.score_sets.and_( + ScoreSet.published_date.is_not(None), ScoreSet.license.has(License.short_name == "CC0") + ) ) ) ) + .execution_options(populate_existing=True) + .order_by(ExperimentSet.urn) + ) + + # Filter the stream of experiment sets to exclude experiments and experiment sets with no public, CC0-licensed score + # sets. + experiment_sets = list(filter_experiment_sets(experiment_sets_query.all())) + + # TODO To support very large data sets, we may want to use custom code for JSON-encoding an iterator. + # Issue: https://github.com/VariantEffect/mavedb-api/issues/192 + # See, for instance, https://stackoverflow.com/questions/12670395/json-encoding-very-long-iterators. + + experiment_set_views = list(map(lambda es: ExperimentSetPublicDump.from_orm(es), experiment_sets)) + + # Get a list of IDS of all the score sets included. + score_set_ids = list( + flatmap(lambda es: flatmap(lambda e: map(lambda ss: ss.id, e.score_sets), es.experiments), experiment_sets) ) - .execution_options(populate_existing=True) - .order_by(ExperimentSet.urn) -) - -# Filter the stream of experiment sets to exclude experiments and experiment sets with no public, CC0-licensed score -# sets. -experiment_sets = list(filter_experiment_sets(experiment_sets_query.all())) - -# TODO To support very large data sets, we may want to use custom code for JSON-encoding an iterator. -# Issue: https://github.com/VariantEffect/mavedb-api/issues/192 -# See, for instance, https://stackoverflow.com/questions/12670395/json-encoding-very-long-iterators. - -experiment_set_views = list(map(lambda es: ExperimentSetPublicDump.from_orm(es), experiment_sets)) - -# Get a list of IDS of all the score sets included. 
-score_set_ids = list( - flatmap(lambda es: flatmap(lambda e: map(lambda ss: ss.id, e.score_sets), es.experiments), experiment_sets) -) - -timestamp_format = "%Y%m%d%H%M%S" -zip_file_name = f"mavedb-dump.{datetime.now().strftime(timestamp_format)}.zip" - -logger.info(f"Exporting public data set metadata to {zip_file_name}/main.json") -json_data = { - "title": "MaveDB public data", - "asOf": datetime.now(timezone.utc).isoformat(), - "experimentSets": experiment_set_views, -} - -with ZipFile(zip_file_name, "w") as zipfile: - # Write metadata for all data sets to a single JSON file. - zipfile.writestr("main.json", json.dumps(jsonable_encoder(json_data))) - - # Copy the CC0 license. - zipfile.write(os.path.join(os.path.dirname(__file__), "resources/CC0_license.txt"), "LICENSE.txt") - - # Write score and count files for each score set. - num_score_sets = len(score_set_ids) - for i, score_set_id in enumerate(score_set_ids): - score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one_or_none() - if score_set is not None and score_set.urn is not None: - logger.info(f"{i + 1}/{num_score_sets} Exporting variants for score set {score_set.urn}") - csv_filename_base = score_set.urn.replace(":", "-") - - csv_str = get_score_set_scores_as_csv(db, score_set) - zipfile.writestr(f"csv/{csv_filename_base}.scores.csv", csv_str) - - count_columns = score_set.dataset_columns["count_columns"] if score_set.dataset_columns else None - if count_columns and len(count_columns) > 0: - csv_str = get_score_set_counts_as_csv(db, score_set) - zipfile.writestr(f"csv/{csv_filename_base}.counts.csv", csv_str) + + timestamp_format = "%Y%m%d%H%M%S" + zip_file_name = f"mavedb-dump.{datetime.now().strftime(timestamp_format)}.zip" + + logger.info(f"Exporting public data set metadata to {zip_file_name}/main.json") + json_data = { + "title": "MaveDB public data", + "asOf": datetime.now(timezone.utc).isoformat(), + "experimentSets": experiment_set_views, + } + + with 
ZipFile(zip_file_name, "w") as zipfile: + # Write metadata for all data sets to a single JSON file. + zipfile.writestr("main.json", json.dumps(jsonable_encoder(json_data))) + + # Copy the CC0 license. + zipfile.write(os.path.join(os.path.dirname(__file__), "resources/CC0_license.txt"), "LICENSE.txt") + + # Write score and count files for each score set. + num_score_sets = len(score_set_ids) + for i, score_set_id in enumerate(score_set_ids): + score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one_or_none() + if score_set is not None and score_set.urn is not None: + logger.info(f"{i + 1}/{num_score_sets} Exporting variants for score set {score_set.urn}") + csv_filename_base = score_set.urn.replace(":", "-") + + csv_str = get_score_set_scores_as_csv(db, score_set) + zipfile.writestr(f"csv/{csv_filename_base}.scores.csv", csv_str) + + count_columns = score_set.dataset_columns["count_columns"] if score_set.dataset_columns else None + if count_columns and len(count_columns) > 0: + csv_str = get_score_set_counts_as_csv(db, score_set) + zipfile.writestr(f"csv/{csv_filename_base}.counts.csv", csv_str) + + +if __name__ == "__main__": + export_public_data() diff --git a/src/mavedb/scripts/populate_mapped_variants.py b/src/mavedb/scripts/populate_mapped_variants.py new file mode 100644 index 000000000..8df46f3dd --- /dev/null +++ b/src/mavedb/scripts/populate_mapped_variants.py @@ -0,0 +1,173 @@ +import logging +import click +from datetime import date +from typing import Sequence, Optional + +from sqlalchemy import cast, select +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Session + +from mavedb.data_providers.services import vrs_mapper +from mavedb.lib.logging.context import format_raised_exception_info_as_dict +from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.score_set import ScoreSet +from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.target_gene import TargetGene 
from mavedb.models.variant import Variant

from mavedb.scripts.environment import script_environment, with_database_session

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def variant_from_mapping(db: Session, mapping: dict, dcd_mapping_version: str) -> MappedVariant:
    """
    Build a current MappedVariant row from a single dcd-mapping result dict.

    :param db: Active SQLAlchemy session used to resolve the variant URN.
    :param mapping: One mapped-score dict from the mapping API; read keys are
        ``mavedb_id``, ``pre_mapped``, ``post_mapped``, ``vrs_version`` and
        ``error_message``.
    :param dcd_mapping_version: Version string of the mapping API that produced
        this mapping.
    :raises sqlalchemy.exc.NoResultFound: if ``mavedb_id`` does not match
        exactly one existing Variant.
    :return: A transient (un-persisted) MappedVariant with ``current=True``.
    """
    variant_urn = mapping.get("mavedb_id")
    variant = db.scalars(select(Variant).where(Variant.urn == variant_urn)).one()

    return MappedVariant(
        variant_id=variant.id,
        pre_mapped=mapping.get("pre_mapped"),
        post_mapped=mapping.get("post_mapped"),
        modification_date=date.today(),
        mapped_date=date.today(),  # since this is a one-time script, assume mapping was done today
        vrs_version=mapping.get("vrs_version"),
        mapping_api_version=dcd_mapping_version,
        error_message=mapping.get("error_message"),
        current=True,
    )


@script_environment.command()
@with_database_session
@click.argument("urns", nargs=-1)
@click.option("--all", help="Populate mapped variants for every score set in MaveDB.", is_flag=True)
def populate_mapped_variant_data(db: Session, urns: Sequence[Optional[str]], all: bool):
    """
    Populate mapped variant data (and target-gene mapping metadata) for the
    given score set URNs, or for every score set when ``--all`` is passed.

    Each score set is processed and committed independently: a failure rolls
    back only that score set's changes and the loop continues.
    """
    score_set_ids: Sequence[Optional[int]]
    if all:
        score_set_ids = db.scalars(select(ScoreSet.id)).all()
        # Fixed: previously logged len(urns), which is always 0 when --all is used.
        logger.info(
            f"Command invoked with --all. Routine will populate mapped variant data for "
            f"{len(score_set_ids)} score sets."
        )
    else:
        score_set_ids = db.scalars(select(ScoreSet.id).where(ScoreSet.urn.in_(urns))).all()
        # Report how many score sets were actually resolved, not how many URNs were supplied.
        logger.info(f"Populating mapped variant data for the provided score sets ({len(score_set_ids)}).")

    vrs = vrs_mapper()
    total = len(score_set_ids)

    for idx, ss_id in enumerate(score_set_ids):
        if not ss_id:
            continue

        score_set = db.scalar(select(ScoreSet).where(ScoreSet.id == ss_id))
        if not score_set:
            logger.warning(f"Could not fetch score set with id={ss_id}.")
            continue

        try:
            # Fixed: restrict the query to the score set being processed. The previous
            # version had no score-set predicate, so every iteration marked ALL current
            # mapped variants in the database as non-current.
            existing_mapped_variants = (
                db.query(MappedVariant)
                .join(Variant)
                .join(ScoreSet)
                .filter(ScoreSet.id == ss_id, MappedVariant.current.is_(True))
                .all()
            )

            for variant in existing_mapped_variants:
                variant.current = False

            assert score_set.urn
            logger.info(f"Mapping score set {score_set.urn}.")
            mapped_scoreset = vrs.map_score_set(score_set.urn)
            logger.info(f"Done mapping score set {score_set.urn}.")

            dcd_mapping_version = mapped_scoreset["dcd_mapping_version"]
            mapped_scores = mapped_scoreset.get("mapped_scores")

            if not mapped_scores:
                # if there are no mapped scores, the score set failed to map.
                score_set.mapping_state = MappingState.failed
                score_set.mapping_errors = {"error_message": mapped_scoreset.get("error_message")}
                # NOTE(review): this explicit per-score-set commit bypasses the
                # commit/rollback control of @with_database_session, so --dry-run
                # still persists changes — confirm whether that is intended.
                db.commit()
                logger.info(f"No mapped variants available for {score_set.urn}.")
            else:
                computed_genomic_ref = mapped_scoreset.get("computed_genomic_reference_sequence")
                mapped_genomic_ref = mapped_scoreset.get("mapped_genomic_reference_sequence")
                computed_protein_ref = mapped_scoreset.get("computed_protein_reference_sequence")
                mapped_protein_ref = mapped_scoreset.get("mapped_protein_reference_sequence")

                # assumes one target gene per score set, which is currently true in mavedb as of sept. 2024.
                target_gene = db.scalars(
                    select(TargetGene)
                    .join(ScoreSet)
                    .where(
                        ScoreSet.urn == str(score_set.urn),
                    )
                ).one()

                # The raw reference sequence is large and redundant; exclude it
                # from the metadata persisted on the target gene.
                excluded_pre_mapped_keys = {"sequence"}
                if computed_genomic_ref and mapped_genomic_ref:
                    pre_mapped_metadata = computed_genomic_ref
                    target_gene.pre_mapped_metadata = cast(
                        {
                            "genomic": {
                                k: pre_mapped_metadata[k]
                                for k in pre_mapped_metadata.keys() - excluded_pre_mapped_keys
                            }
                        },
                        JSONB,
                    )
                    target_gene.post_mapped_metadata = cast({"genomic": mapped_genomic_ref}, JSONB)
                elif computed_protein_ref and mapped_protein_ref:
                    pre_mapped_metadata = computed_protein_ref
                    target_gene.pre_mapped_metadata = cast(
                        {
                            "protein": {
                                k: pre_mapped_metadata[k]
                                for k in pre_mapped_metadata.keys() - excluded_pre_mapped_keys
                            }
                        },
                        JSONB,
                    )
                    target_gene.post_mapped_metadata = cast({"protein": mapped_protein_ref}, JSONB)
                else:
                    raise ValueError(f"incomplete or inconsistent metadata for score set {score_set.urn}")

                mapped_variants = [
                    variant_from_mapping(db=db, mapping=mapped_score, dcd_mapping_version=dcd_mapping_version)
                    for mapped_score in mapped_scores
                ]
                logger.debug(f"Done constructing {len(mapped_variants)} mapped variant objects.")

                num_successful_variants = len(
                    [variant for variant in mapped_variants if variant.post_mapped is not None]
                )
                logger.debug(
                    f"{num_successful_variants}/{len(mapped_variants)} variants generated a post-mapped VRS object."
                )

                if num_successful_variants == 0:
                    score_set.mapping_state = MappingState.failed
                    score_set.mapping_errors = {"error_message": "All variants failed to map"}
                elif num_successful_variants < len(mapped_variants):
                    score_set.mapping_state = MappingState.incomplete
                else:
                    score_set.mapping_state = MappingState.complete

                db.bulk_save_objects(mapped_variants)
                # NOTE(review): see commit note above — also bypasses --dry-run.
                db.commit()
                logger.info(f"Done populating {len(mapped_variants)} mapped variants for {score_set.urn}.")

        except Exception as e:
            # Fixed: slice the list actually being iterated (score_set_ids); the
            # previous urns[:idx]/urns[idx:] slices were empty under --all and
            # misaligned otherwise (urns may resolve to fewer score sets).
            logging_context = {
                "mapped_score_sets": score_set_ids[:idx],
                "unmapped_score_sets": score_set_ids[idx:],
            }
            logging_context = {**logging_context, **format_raised_exception_info_as_dict(e)}
            logger.error(f"Score set {score_set.urn} failed to map.", extra=logging_context)
            logger.info(f"Rolling back all changes for scoreset {score_set.urn}")
            db.rollback()

        # Fixed: progress denominator was len(urns); use the resolved score set count.
        logger.info(f"Done with score set {score_set.urn}. ({idx + 1}/{total}).")

    logger.info("Done populating mapped variant data.")


if __name__ == "__main__":
    populate_mapped_variant_data()