Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions alembic/manual_migrations/refresh_published_tmp_urns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import sqlalchemy as sa
from sqlalchemy.orm import Session, configure_mappers

from mavedb.models import *

from mavedb.lib.score_sets import refresh_variant_urns

from mavedb.models.score_set import ScoreSet
from mavedb.models.variant import Variant

from mavedb.db.session import SessionLocal

configure_mappers()


def do_migration(db: Session):
published_score_sets_with_associated_tmp_variants: sa.ScalarResult[str]
published_score_sets_with_associated_tmp_variants = db.execute(
sa.select(sa.distinct(ScoreSet.urn)).join(Variant).where(ScoreSet.published_date.is_not(None), Variant.urn.like("%tmp:%"))
).scalars()

for score_set_urn in published_score_sets_with_associated_tmp_variants:
refresh_variant_urns(db, db.execute(sa.select(ScoreSet).where(ScoreSet.urn == score_set_urn)).scalar_one())


if __name__ == "__main__":
db = SessionLocal()
db.current_user = None # type: ignore

do_migration(db)

db.commit()
db.close()
15 changes: 15 additions & 0 deletions src/mavedb/lib/score_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,21 @@ def create_variants(db, score_set: ScoreSet, variants_data: list[VariantData], b
return len(score_set.variants)


def refresh_variant_urns(db: Session, score_set: ScoreSet):
variants = db.execute(select(Variant).where(Variant.score_set_id == score_set.id)).scalars()

for variant in variants:
if not variant.urn:
raise ValueError("All variants should have an associated URN.")

variant_number = variant.urn.split("#")[1]
refreshed_urn = f"{score_set.urn}#{variant_number}"
variant.urn = refreshed_urn
db.add(variant)

db.commit()


def bulk_create_urns(n, score_set, reset_counter=False) -> list[str]:
start_value = 0 if reset_counter else score_set.num_variants
parent_urn = score_set.urn
Expand Down
2 changes: 2 additions & 0 deletions src/mavedb/routers/score_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from mavedb.lib.score_sets import (
search_score_sets as _search_score_sets,
refresh_variant_urns,
)
from mavedb.lib.taxonomies import find_or_create_taxonomy
from mavedb.lib.urns import (
Expand Down Expand Up @@ -1034,6 +1035,7 @@ def publish_score_set(
item.urn = generate_score_set_urn(db, item.experiment)
item.private = False
item.published_date = published_date
refresh_variant_urns(db, item)

save_to_logging_context({"score_set": item.urn})

Expand Down
25 changes: 25 additions & 0 deletions tests/routers/test_score_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

import jsonschema
from arq import ArqRedis
from sqlalchemy import select

from mavedb.lib.validation.urn_re import MAVEDB_TMP_URN_RE
from mavedb.models.enums.processing_state import ProcessingState
from mavedb.models.experiment import Experiment as ExperimentDbModel
from mavedb.models.score_set import ScoreSet as ScoreSetDbModel
from mavedb.models.variant import Variant as VariantDbModel
from mavedb.view_models.orcid import OrcidUser
from mavedb.view_models.score_set import ScoreSet, ScoreSetCreate
from tests.helpers.constants import (
Expand Down Expand Up @@ -593,6 +595,11 @@ def test_publish_score_set(session, data_provider, client, setup_router_db, data
for key in expected_response:
assert (key, expected_response[key]) == (key, score_set[key])

score_set_variants = session.execute(
select(VariantDbModel).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set["urn"])
).scalars()
assert all([variant.urn.startswith("urn:mavedb:") for variant in score_set_variants])


def test_publish_multiple_score_sets(session, data_provider, client, setup_router_db, data_files):
experiment = create_experiment(client)
Expand Down Expand Up @@ -625,6 +632,19 @@ def test_publish_multiple_score_sets(session, data_provider, client, setup_route
assert pub_score_set_3_data["title"] == score_set_3["title"]
assert pub_score_set_3_data["experiment"]["urn"] == "urn:mavedb:00000001-a"

score_set_1_variants = session.execute(
select(VariantDbModel).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_1["urn"])
).scalars()
assert all([variant.urn.startswith("urn:mavedb:") for variant in score_set_1_variants])
score_set_2_variants = session.execute(
select(VariantDbModel).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_2["urn"])
).scalars()
assert all([variant.urn.startswith("urn:mavedb:") for variant in score_set_2_variants])
score_set_3_variants = session.execute(
select(VariantDbModel).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_3["urn"])
).scalars()
assert all([variant.urn.startswith("urn:mavedb:") for variant in score_set_3_variants])


def test_cannot_publish_score_set_without_variants(client, setup_router_db):
experiment = create_experiment(client)
Expand Down Expand Up @@ -727,6 +747,11 @@ def test_contributor_can_publish_other_users_score_set(session, data_provider, c
for key in expected_response:
assert (key, expected_response[key]) == (key, score_set[key])

score_set_variants = session.execute(
select(VariantDbModel).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set["urn"])
).scalars()
assert all([variant.urn.startswith("urn:mavedb:") for variant in score_set_variants])


def test_admin_cannot_publish_other_user_private_score_set(
session, data_provider, client, admin_app_overrides, setup_router_db, data_files
Expand Down