From 8bd47fd6170afcaed852920ec36759e9a375bb06 Mon Sep 17 00:00:00 2001 From: pkarmarkar Date: Fri, 2 Jan 2026 13:15:49 -0800 Subject: [PATCH] feat: Add DatabaseSessionService with JDBC and Flyway support fixes #665 Summary Implements a production-ready database-backed session service that provides persistent storage for ADK sessions, events, and state using JDBC. Key Features JDBC with HikariCP connection pooling for optimal performance Flyway migrations for schema versioning and zero-downtime deployments Multi-database support: PostgreSQL, MySQL, H2, Cloud Spanner, and other RDBMS Thread-safe operations with pessimistic locking for concurrent updates Comprehensive test coverage with H2 in-memory database and Integration tests for PostgreSQL , MySQL and Spanner(using emulator). Dialect-aware JSON storage (JSONB for PostgreSQL, CLOB for others) Event filtering and pagination for efficient data retrieval Architecture Located in contrib/database-session-service module Minimal core dependencies footprint Users explicitly opt-in via dependency Follows existing contrib pattern Tasks [] Implement DatabaseSessionService with JDBC (no ORM dependencies) [] Add multi-database support (PostgreSQL, MySQL, H2, Spanner) [] Implement 3-tier state storage (app/user/session levels) [] Add Flyway migrations for schema management [] Add comprehensive test suite (unit + integration tests) [] Add documentation and usage examples [] Address code review feedback --- contrib/database-session-service/README.md | 87 ++ contrib/database-session-service/pom.xml | 186 +++ .../adk/sessions/DatabaseSessionService.java | 1180 +++++++++++++++++ .../com/google/adk/sessions/dao/EventDao.java | 141 ++ .../google/adk/sessions/dao/SessionDao.java | 131 ++ .../com/google/adk/sessions/dao/StateDao.java | 110 ++ .../adk/sessions/dialect/DialectDetector.java | 45 + .../adk/sessions/dialect/H2Dialect.java | 33 + .../adk/sessions/dialect/MySqlDialect.java | 33 + .../adk/sessions/dialect/PostgresDialect.java | 35 + .../adk/sessions/dialect/SpannerDialect.java | 65 + .../adk/sessions/dialect/SqlDialect.java | 56 + .../adk/sessions/model/AppStateRow.java | 33 + .../google/adk/sessions/model/EventRow.java | 69 + .../google/adk/sessions/model/SessionRow.java | 60 + .../adk/sessions/model/UserStateRow.java | 42 + .../adk/sessions/util/JdbcTemplate.java | 101 ++ .../sessions/util/NamedParameterSupport.java | 58 + .../google/adk/sessions/util/RowMapper.java | 9 + .../db/migration/h2/V1__Initial_schema.sql | 66 + .../db/migration/mysql/V1__Initial_schema.sql | 66 + .../postgresql/V1__Initial_schema.sql | 66 + .../migration/spanner/V1__Initial_schema.sql | 59 + .../adk/sessions/AppUserStateLockingTest.java | 298 +++++ .../AppendEventRaceConditionTest.java | 300 +++++ .../google/adk/sessions/AppendEventTest.java | 259 ++++ .../ConcurrentSessionOperationsTest.java | 498 +++++++ .../sessions/ContentSerializationTest.java | 522 ++++++++ .../CreateSessionInTransactionTest.java | 246 ++++ .../sessions/DatabaseSessionServiceTest.java | 693 ++++++++++ .../google/adk/sessions/DiagnosticTest.java | 112 ++ .../adk/sessions/DialectDetectorTest.java | 119 ++ .../adk/sessions/EventFilteringTest.java | 297 +++++ .../adk/sessions/FlywayMigrationTest.java | 394 ++++++ .../google/adk/sessions/HikariConfigTest.java | 216 +++ .../adk/sessions/ListSessionsEventsTest.java | 89 ++ .../sessions/MySQLAgentIntegrationTest.java | 355 +++++ .../adk/sessions/MySQLFunctionalTest.java | 184 +++ .../adk/sessions/MySQLIntegrationTest.java | 809 +++++++++++ .../adk/sessions/NegativeTestCases.java | 410 ++++++ .../adk/sessions/PessimisticLockingTest.java | 340 +++++ .../PostgreSQLAgentIntegrationTest.java | 364 +++++ .../sessions/PostgreSQLFunctionalTest.java | 184 +++ .../sessions/PostgreSQLIntegrationTest.java | 807 +++++++++++ .../sessions/ReadTwiceNonDestructiveTest.java | 267 ++++ .../adk/sessions/SessionUpdateTimeTest.java | 361 +++++ .../adk/sessions/SpannerFunctionalTest.java | 184 +++ .../adk/sessions/SpannerIntegrationTest.java | 808 +++++++++++ .../google/adk/sessions/StateDeltaTest.java | 400 ++++++ .../adk/sessions/StateManagementTest.java | 392 ++++++ .../adk/sessions/StatePrefixHandlingTest.java | 334 +++++ .../adk/testing/TestDatabaseConfig.java | 125 ++ .../java/com/google/adk/testing/TestLlm.java | 310 +++++ pom.xml | 31 +- 54 files changed, 13438 insertions(+), 1 deletion(-) create mode 100644 contrib/database-session-service/README.md create mode 100644 contrib/database-session-service/pom.xml create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/DatabaseSessionService.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/EventDao.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/SessionDao.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/StateDao.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/DialectDetector.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/H2Dialect.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/MySqlDialect.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/PostgresDialect.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SpannerDialect.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SqlDialect.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/model/AppStateRow.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/model/EventRow.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/model/SessionRow.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/model/UserStateRow.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/util/JdbcTemplate.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/util/NamedParameterSupport.java create mode 100644 contrib/database-session-service/src/main/java/com/google/adk/sessions/util/RowMapper.java create mode 100644 contrib/database-session-service/src/main/resources/db/migration/h2/V1__Initial_schema.sql create mode 100644 contrib/database-session-service/src/main/resources/db/migration/mysql/V1__Initial_schema.sql create mode 100644 contrib/database-session-service/src/main/resources/db/migration/postgresql/V1__Initial_schema.sql create mode 100644 contrib/database-session-service/src/main/resources/db/migration/spanner/V1__Initial_schema.sql create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/AppUserStateLockingTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/AppendEventRaceConditionTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/AppendEventTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/ConcurrentSessionOperationsTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/ContentSerializationTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/CreateSessionInTransactionTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/DatabaseSessionServiceTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/DiagnosticTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/DialectDetectorTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/EventFilteringTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/FlywayMigrationTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/HikariConfigTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/ListSessionsEventsTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLAgentIntegrationTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLFunctionalTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLIntegrationTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/NegativeTestCases.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/PessimisticLockingTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLAgentIntegrationTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLFunctionalTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLIntegrationTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/ReadTwiceNonDestructiveTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/SessionUpdateTimeTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/SpannerFunctionalTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/SpannerIntegrationTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/StateDeltaTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/StateManagementTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/sessions/StatePrefixHandlingTest.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/testing/TestDatabaseConfig.java create mode 100644 contrib/database-session-service/src/test/java/com/google/adk/testing/TestLlm.java diff --git a/contrib/database-session-service/README.md b/contrib/database-session-service/README.md new file mode 100644 index 000000000..90b5d5415 --- /dev/null +++ b/contrib/database-session-service/README.md @@ -0,0 +1,87 @@ +# Database Session Service + +JDBC-based session service implementation for ADK Java. + +## Features + +- **No ORM Dependencies**: Uses JDBC with HikariCP for connection pooling +- **Multi-Database Support**: PostgreSQL, MySQL, H2 (SQLite not supported) +- **Automatic Schema Management**: Flyway migrations handle table creation/updates +- **3-Tier State Storage**: Separate tables for app-level, user-level, and session-level state +- **Reactive API**: RxJava 3 Single/Maybe/Completable return types + +## Dependencies + +- **HikariCP**: High-performance JDBC connection pool +- **Flyway**: Database schema versioning and migration +- **Jackson**: JSON serialization for events and state +- **RxJava 3**: Reactive programming support + +## Database Schema + +The service creates and manages these tables: + +- `app_states`: Application-level state (shared across all users) +- `user_states`: User-level state (shared across user's sessions) +- `sessions`: Individual session data +- `events`: Event history for each session + +## Usage + +```java +// Create service with database URL +String dbUrl = "jdbc:postgresql://localhost:5432/adk?user=postgres&password=secret"; +try (DatabaseSessionService sessionService = new DatabaseSessionService(dbUrl)) { + + // Create a session + Session session = sessionService.createSession( + "myApp", + "user123", + new ConcurrentHashMap<>(), + null + ).blockingGet(); + + // Append an event + Event event = Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-1") + .timestamp(System.currentTimeMillis()) + .build(); + + Event appendedEvent = sessionService.appendEvent(session, event).blockingGet(); +} +``` + +## Supported Databases + +- **PostgreSQL**: Full support with JSONB + - URL: `jdbc:postgresql://host:port/database?user=...&password=...` +- **MySQL**: Full support with JSON + - URL: `jdbc:mysql://host:port/database?user=...&password=...` +- **H2**: For testing and development + - URL: `jdbc:h2:mem:testdb` or `jdbc:h2:file:./data/mydb` +- **Cloud Spanner**: Full support + - URL: `jdbc:cloudspanner:/projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID` +- **SQLite**: NOT supported (no UPSERT support) + +## State Management + +State is stored across three tables with merge priority: + +1. **App State** (lowest priority): `app:key` prefix +2. **User State** (medium priority): `user:key` prefix +3. **Session State** (highest priority): No prefix + +When retrieving a session, states are merged: app → user → session (higher priority overwrites). + +## Configuration + +Optional properties can be passed to the constructor: + +```java +Map props = new HashMap<>(); +props.put("connectionTimeout", 30000); +props.put("maximumPoolSize", 10); + +DatabaseSessionService service = new DatabaseSessionService(dbUrl, props); +``` diff --git a/contrib/database-session-service/pom.xml b/contrib/database-session-service/pom.xml new file mode 100644 index 000000000..bfab7cd22 --- /dev/null +++ b/contrib/database-session-service/pom.xml @@ -0,0 +1,186 @@ + + + + 4.0.0 + + + com.google.adk + google-adk-parent + 0.5.1-SNAPSHOT + ../../pom.xml + + + google-adk-database-session-service + Agent Development Kit - Database Session Service + Database integration with Agent Development Kit for User Session Management + + + + + com.google.adk + google-adk + ${project.version} + + + + + + + com.zaxxer + HikariCP + + + + + com.fasterxml.jackson.core + jackson-databind + + + + + io.reactivex.rxjava3 + rxjava + + + + + org.slf4j + slf4j-api + + + + + + + org.flywaydb + flyway-core + + + + + org.flywaydb + flyway-database-postgresql + runtime + + + + + org.flywaydb + flyway-mysql + runtime + + + + + org.flywaydb + flyway-gcp-spanner + runtime + + + + + + + org.postgresql + postgresql + true + + + + + com.mysql + mysql-connector-j + true + + + + + com.google.cloud + google-cloud-spanner-jdbc + true + + + + + + + com.h2database + h2 + test + + + + org.junit.jupiter + junit-jupiter-api + test + + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.slf4j + slf4j-simple + test + + + + com.google.truth + truth + test + + + + org.mockito + mockito-core + test + + + + + + + src/main/resources + true + + + + + maven-compiler-plugin + + + org.jacoco + jacoco-maven-plugin + + + org.apache.maven.plugins + maven-surefire-plugin + + + + diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/DatabaseSessionService.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/DatabaseSessionService.java new file mode 100644 index 000000000..d47001273 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/DatabaseSessionService.java @@ -0,0 +1,1180 @@ +package com.google.adk.sessions; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.events.Event; +import com.google.adk.sessions.dao.EventDao; +import com.google.adk.sessions.dao.SessionDao; +import com.google.adk.sessions.dao.StateDao; +import com.google.adk.sessions.dialect.DialectDetector; +import com.google.adk.sessions.dialect.SqlDialect; +import com.google.adk.sessions.model.AppStateRow; +import com.google.adk.sessions.model.EventRow; +import com.google.adk.sessions.model.SessionRow; +import com.google.adk.sessions.model.UserStateRow; +import com.google.adk.sessions.util.JdbcTemplate; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import org.flywaydb.core.Flyway; +import org.flywaydb.core.api.FlywayException; +import org.flywaydb.core.api.output.MigrateResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * JDBC-based implementation of {@link BaseSessionService}. + * + *

This service provides persistent session management using JDBC and HikariCP connection + * pooling. It supports multiple databases with automatic dialect detection and schema management + * via Flyway migrations. + * + *

Features: + * + *

+ * + *

Supported Databases: + * + *

+ * + *

State Management: + * + *

This service implements a 3-tier state storage model: + * + *

+ * + *

State is merged with priority: App → User → Session (higher priority overwrites lower). + * + *

Thread Safety: + * + *

This class is thread-safe. All database operations use connection pooling and database + * transactions. Concurrent operations on the same session are serialized via pessimistic locking + * (SELECT ... FOR UPDATE). + * + *

Resource Management: + * + *

This class implements {@link AutoCloseable}. Always use try-with-resources or explicitly call + * {@link #close()} to release database connections: + * + *

{@code
+ * try (DatabaseSessionService service = new DatabaseSessionService(jdbcUrl)) {
+ *     // Use the service
+ * } // Connections automatically closed
+ * }
+ * + *

Example Usage: + * + *

{@code
+ * // Create service with PostgreSQL
+ * String jdbcUrl = "jdbc:postgresql://localhost:5432/adk?user=postgres&password=secret";
+ * try (DatabaseSessionService service = new DatabaseSessionService(jdbcUrl)) {
+ *
+ *     // Create a session with initial state
+ *     ConcurrentMap state = new ConcurrentHashMap<>();
+ *     state.put("app:version", "1.0");      // App-level state
+ *     state.put("user:theme", "dark");       // User-level state
+ *     state.put("currentStep", 1);           // Session-level state
+ *
+ *     Session session = service.createSession("myApp", "user123", state, null).blockingGet();
+ *
+ *     // Append an event with state delta
+ *     Event event = Event.builder()
+ *         .id(UUID.randomUUID().toString())
+ *         .invocationId("inv-1")
+ *         .timestamp(System.currentTimeMillis())
+ *         .actions(EventActions.builder()
+ *             .stateDelta(Map.of("currentStep", 2))
+ *             .build())
+ *         .build();
+ *
+ *     Event appendedEvent = service.appendEvent(session, event).blockingGet();
+ *
+ *     // Retrieve updated session
+ *     Session updated = service.getSession("myApp", "user123", session.id(), Optional.empty())
+ *         .blockingGet();
+ * }
+ * }
+ * + * @see BaseSessionService + * @see Session + * @see Event + * @see State + */ +public class DatabaseSessionService implements BaseSessionService, AutoCloseable { + + private static final Logger logger = LoggerFactory.getLogger(DatabaseSessionService.class); + + private final HikariDataSource dataSource; + private final JdbcTemplate jdbcTemplate; + private final SqlDialect dialect; + private final ObjectMapper objectMapper; + private final SessionDao sessionDao; + private final EventDao eventDao; + private final StateDao stateDao; + private final AtomicBoolean closed = new AtomicBoolean(false); + + /** + * Creates a new DatabaseSessionService with default configuration. + * + *

This constructor uses default HikariCP connection pool settings: max pool size = 10, min + * idle = 2, connection timeout = 30s. + * + *

The database dialect is automatically detected from the JDBC URL, and Flyway migrations are + * run automatically to create/update the schema. + * + * @param jdbcUrl the JDBC connection URL (e.g., {@code + * jdbc:postgresql://localhost:5432/adk?user=postgres&password=secret}) + * @throws NullPointerException if jdbcUrl is null + * @throws IllegalArgumentException if the database dialect cannot be detected from the JDBC URL + * @throws SessionException if database migration fails + * @see #DatabaseSessionService(String, Map) + */ + public DatabaseSessionService(String jdbcUrl) { + this(jdbcUrl, Collections.emptyMap()); + } + + /** + * Creates a new DatabaseSessionService with custom configuration. + * + *

This constructor allows customization of both HikariCP connection pool settings and + * database-specific properties. + * + *

Example with custom connection pool settings: + * + *

{@code
+   * Map props = new HashMap<>();
+   * props.put("hikari.connectionTimeout", 60000);  // 60 seconds
+   * props.put("hikari.maximumPoolSize", 20);       // 20 connections
+   * DatabaseSessionService service = new DatabaseSessionService(jdbcUrl, props);
+   * }
+ * + *

Supported HikariCP properties (prefix with "hikari."): + * + *

+ * + *

All properties without the "hikari." prefix are passed to the underlying DataSource. + * + * @param jdbcUrl the JDBC connection URL + * @param properties configuration properties for HikariCP and the DataSource + * @throws NullPointerException if jdbcUrl is null + * @throws IllegalArgumentException if the database dialect cannot be detected from the JDBC URL + * @throws SessionException if database migration fails + */ + public DatabaseSessionService(String jdbcUrl, Map properties) { + Objects.requireNonNull(jdbcUrl, "JDBC URL cannot be null"); + + this.dialect = DialectDetector.detectFromJdbcUrl(jdbcUrl); + logger.info("Detected SQL dialect: {}", dialect.dialectName()); + + runMigrations(jdbcUrl); + + this.dataSource = createDataSource(jdbcUrl, properties); + + this.jdbcTemplate = new JdbcTemplate(dataSource); + this.objectMapper = com.google.adk.JsonBaseModel.getMapper(); + this.sessionDao = new SessionDao(dialect); + this.eventDao = new EventDao(dialect); + this.stateDao = new StateDao(dialect); + + logger.info( + "DatabaseSessionService initialized with {} (JDBC implementation)", dialect.dialectName()); + } + + private void runMigrations(String jdbcUrl) { + try { + String dialectFolder = extractDialectFolder(dialect.dialectName()); + String flywayLocation = "classpath:db/migration/" + dialectFolder; + + logger.info("Starting Flyway database migration"); + logger.info("Dialect: {}", dialect.dialectName()); + logger.info("Migration location: {}", flywayLocation); + logger.info("JDBC URL: {}", jdbcUrl.replaceAll("password=[^&;]*", "password=***")); + + String baselineOnMigrateStr = + System.getProperty( + "FLYWAY_BASELINE_ON_MIGRATE", + System.getenv().getOrDefault("FLYWAY_BASELINE_ON_MIGRATE", "false")); + boolean baselineOnMigrate = Boolean.parseBoolean(baselineOnMigrateStr); + + String lockRetryCountStr = + System.getProperty( + "FLYWAY_LOCK_RETRY_COUNT", + System.getenv().getOrDefault("FLYWAY_LOCK_RETRY_COUNT", "120")); + int lockRetryCount = Integer.parseInt(lockRetryCountStr); + + logger.info( + "Flyway configuration: baselineOnMigrate={}, lockRetryCount={}", + baselineOnMigrate, + lockRetryCount); + + Flyway flyway = + Flyway.configure() + .dataSource(jdbcUrl, null, null) + .locations(flywayLocation) + .cleanDisabled(true) + .lockRetryCount(lockRetryCount) + .baselineOnMigrate(baselineOnMigrate) + .load(); + + MigrateResult result = flyway.migrate(); + + if (result.migrationsExecuted > 0) { + logger.info( + "Flyway migration completed: {} migration(s) applied successfully", + result.migrationsExecuted); + } else { + logger.info("Database schema is up to date (no migrations applied)"); + } + logger.info("Flyway migration complete"); + } catch (FlywayException e) { + throw new SessionException("Failed to run database migrations", e); + } + } + + private String extractDialectFolder(String dialectName) { + String lower = dialectName.toLowerCase(); + if (lower.contains("postgres")) return "postgresql"; + if (lower.contains("mysql")) return "mysql"; + if (lower.contains("h2")) return "h2"; + if (lower.contains("spanner")) return "spanner"; + throw new IllegalArgumentException("Unsupported dialect: " + dialectName); + } + + private HikariDataSource createDataSource(String jdbcUrl, Map properties) { + HikariConfig config = new HikariConfig(); + config.setJdbcUrl(jdbcUrl); + + int maxPoolSize = getIntProperty(properties, "hikari.maximumPoolSize", 10); + int minIdle = getIntProperty(properties, "hikari.minimumIdle", 2); + long connTimeout = getLongProperty(properties, "hikari.connectionTimeout", 30000L); + long idleTimeout = getLongProperty(properties, "hikari.idleTimeout", 600000L); + long maxLifetime = getLongProperty(properties, "hikari.maxLifetime", 1800000L); + + config.setMaximumPoolSize(maxPoolSize); + config.setMinimumIdle(minIdle); + config.setConnectionTimeout(connTimeout); + config.setIdleTimeout(idleTimeout); + config.setMaxLifetime(maxLifetime); + + properties.entrySet().stream() + .filter(e -> !e.getKey().startsWith("hikari.")) + .forEach(e -> config.addDataSourceProperty(e.getKey(), e.getValue())); + + logger.debug("Initializing HikariCP connection pool"); + logger.debug( + "Pool configuration: maxPoolSize={}, minIdle={}, connectionTimeout={}ms, idleTimeout={}ms, maxLifetime={}ms", + maxPoolSize, + minIdle, + connTimeout, + idleTimeout, + maxLifetime); + + HikariDataSource dataSource = new HikariDataSource(config); + logger.debug("HikariCP connection pool created successfully"); + return dataSource; + } + + /** + * Creates a new session with the specified parameters. + * + *

This method creates a new session and initializes the 3-tier state storage (app, user, + * session). If the provided state map contains keys with prefixes {@code app:} or {@code user:}, + * those entries are stored in the corresponding state tables. + * + *

State Handling: + * + *

+ * + *

If app or user state already exists, it is updated with the new values. + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param state the initial state map, can be null or empty + * @param sessionId optional session ID; if null or empty, a UUID is generated + * @return a Single that emits the created Session + * @throws NullPointerException if appName or userId is null (checked by BaseSessionService) + * @see State#APP_PREFIX + * @see State#USER_PREFIX + * @see State#TEMP_PREFIX + */ + @Override + public Single createSession( + String appName, String userId, ConcurrentMap state, String sessionId) { + checkNotClosed(); + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + + return Single.fromCallable(() -> createSessionInTransaction(appName, userId, state, sessionId)) + .subscribeOn(Schedulers.io()); + } + + /** + * Retrieves a session by its identifiers. + * + *

This method fetches the session from the database and merges the 3-tier state (app → user → + * session) before returning. The returned session includes all events up to the specified limit. + * + *

Event Filtering: + * + *

+ * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param sessionId the session identifier (must not be null) + * @param config optional configuration for event filtering + * @return a Maybe that emits the Session if found, or completes empty if not found + * @throws NullPointerException if appName, userId, or sessionId is null + * @see GetSessionConfig + */ + @Override + public Maybe getSession( + String appName, String userId, String sessionId, Optional config) { + checkNotClosed(); + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + Objects.requireNonNull(sessionId, "sessionId cannot be null"); + Objects.requireNonNull(config, "config cannot be null"); + + return Maybe.fromCallable( + () -> + jdbcTemplate.inTransaction( + ops -> { + Optional sessionOpt = + sessionDao.findSession(ops, appName, userId, sessionId); + + if (!sessionOpt.isPresent()) { + return null; + } + + return buildSessionFromRow(ops, sessionOpt.get(), config); + })) + .subscribeOn(Schedulers.io()); + } + + /** + * Lists all sessions for a specific application and user. + * + *

The sessions are returned without events and without merged app/user state (state maps will + * be empty). Use {@link #getSession} to retrieve full session details. + * + *

Sessions are ordered by update_time descending (most recently updated first), with a limit + * of 1000 sessions. + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @return a Single that emits a ListSessionsResponse containing the sessions + * @throws NullPointerException if appName or userId is null + */ + @Override + public Single listSessions(String appName, String userId) { + checkNotClosed(); + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + + return Single.fromCallable( + () -> { + return jdbcTemplate.inTransaction( + ops -> { + List sessionRows = sessionDao.listSessions(ops, appName, userId); + + List sessions = + sessionRows.stream() + .map( + row -> + toDomainSession( + row, + Collections.emptyList(), + new ConcurrentHashMap<>(), + new ConcurrentHashMap<>())) + .collect(Collectors.toList()); + + return ListSessionsResponse.builder().sessions(sessions).build(); + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Deletes a session and all its associated events. + * + *

This operation cascades to delete all events associated with the session. App-level and + * user-level state are NOT deleted (they may be shared with other sessions). + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param sessionId the session identifier (must not be null) + * @return a Completable that completes when the session is deleted + * @throws SessionNotFoundException if the session does not exist + * @throws NullPointerException if any parameter is null + */ + @Override + public Completable deleteSession(String appName, String userId, String sessionId) { + checkNotClosed(); + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + Objects.requireNonNull(sessionId, "sessionId cannot be null"); + + return Completable.fromAction( + () -> { + jdbcTemplate.inTransaction( + ops -> { + // Idempotent delete - no error if session doesn't exist + sessionDao.deleteSession(ops, appName, userId, sessionId); + return null; + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Lists all events for a session. + * + *

This method fetches ALL events for the session in chronological order (oldest first). + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param sessionId the session identifier (must not be null) + * @return a Single that emits a ListEventsResponse containing all events + * @throws SessionNotFoundException if the session does not exist + * @throws IllegalStateException if the service has been closed + */ + @Override + public Single listEvents(String appName, String userId, String sessionId) { + checkNotClosed(); + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + Objects.requireNonNull(sessionId, "sessionId cannot be null"); + + return Single.fromCallable( + () -> { + return jdbcTemplate.inTransaction( + ops -> { + Optional sessionOpt = + sessionDao.findSession(ops, appName, userId, sessionId); + + if (!sessionOpt.isPresent()) { + throw new SessionNotFoundException( + "Session not found: " + appName + "/" + userId + "/" + sessionId); + } + + List eventRows = eventDao.listEvents(ops, appName, userId, sessionId); + + List events = + eventRows.stream().map(this::toEvent).collect(Collectors.toList()); + + return ListEventsResponse.builder().events(events).build(); + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Appends an event to a session and persists it to the database. + * + *

This method processes the event's state delta (if present) and applies it to the appropriate + * state tier (app, user, or session). The session's update_time is refreshed even if there is no + * state delta. + * + *

State Delta Processing: + * + *

    + *
  • Keys starting with {@code app:} update app-level state + *
  • Keys starting with {@code user:} update user-level state + *
  • Keys starting with {@code temp:} are ignored (not persisted) + *
  • All other keys update session-level state + *
  • Use {@link State#REMOVED} as a value to delete a state key + *
+ * + *

This operation uses pessimistic locking (SELECT ... FOR UPDATE) to prevent concurrent + * modifications to the same session. + * + * @param session the session to append the event to (must not be null) + * @param event the event to append (must not be null) + * @return a Single that emits the updated Event after processing + * @throws NullPointerException if session or event is null + * @throws SessionNotFoundException if the session does not exist + * @throws IllegalStateException if the service has been closed + * @see State#REMOVED + */ + @Override + public Single appendEvent(Session session, Event event) { + checkNotClosed(); + Objects.requireNonNull(session, "session cannot be null"); + Objects.requireNonNull(event, "event cannot be null"); + Objects.requireNonNull(session.appName(), "session.appName cannot be null"); + Objects.requireNonNull(session.userId(), "session.userId cannot be null"); + Objects.requireNonNull(session.id(), "session.id cannot be null"); + + // DB first, then memory + // If DB fails, transaction rolls back and memory is never updated + return persistEventToDatabase(session.appName(), session.userId(), session.id(), event) + .andThen(BaseSessionService.super.appendEvent(session, event)) + .doOnError( + throwable -> { + logger.error( + "Failed to append event to session {}/{}/{}: {}", + session.appName(), + session.userId(), + session.id(), + throwable.getMessage(), + throwable); + }); + } + + /** + * Persists an event to the database. + * + *

This method handles event persistence and state delta updates. It acquires row-level locks + * on the session, app state, and user state during the transaction. + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param sessionId the session identifier (must not be null) + * @param event the event to append (must not be null) + * @return a Completable that completes when the event is persisted + * @throws SessionNotFoundException if the session does not exist + */ + private Completable persistEventToDatabase( + String appName, String userId, String sessionId, Event event) { + return Completable.fromAction( + () -> { + jdbcTemplate.inTransaction( + ops -> { + Optional sessionOpt = + sessionDao.findSessionForUpdate(ops, appName, userId, sessionId); + + if (!sessionOpt.isPresent()) { + throw new SessionNotFoundException( + "Session not found: " + appName + "/" + userId + "/" + sessionId); + } + + SessionRow sessionRow = sessionOpt.get(); + + Optional appStateOpt = stateDao.getAppStateForUpdate(ops, appName); + Map appState = + appStateOpt + .map(s -> fromJson(s.getState())) + .orElse(new ConcurrentHashMap<>()); + + Optional userStateOpt = + stateDao.getUserStateForUpdate(ops, appName, userId); + Map userState = + userStateOpt + .map(s -> fromJson(s.getState())) + .orElse(new ConcurrentHashMap<>()); + + if (event.actions() != null && event.actions().stateDelta() != null) { + ConcurrentMap stateDelta = event.actions().stateDelta(); + + Map appStateDelta = new ConcurrentHashMap<>(); + Map userStateDelta = new ConcurrentHashMap<>(); + Map sessionStateDelta = new ConcurrentHashMap<>(); + + for (Map.Entry entry : stateDelta.entrySet()) { + String key = entry.getKey(); + if (key.startsWith(State.APP_PREFIX)) { + String unprefixedKey = key.substring(State.APP_PREFIX.length()); + appStateDelta.put(unprefixedKey, entry.getValue()); + } else if (key.startsWith(State.USER_PREFIX)) { + String unprefixedKey = key.substring(State.USER_PREFIX.length()); + userStateDelta.put(unprefixedKey, entry.getValue()); + } else if (!key.startsWith(State.TEMP_PREFIX)) { + sessionStateDelta.put(key, entry.getValue()); + } + } + + if (!appStateDelta.isEmpty()) { + for (Map.Entry entry : appStateDelta.entrySet()) { + if (entry.getValue() == State.REMOVED) { + appState.remove(entry.getKey()); + } else { + appState.put(entry.getKey(), entry.getValue()); + } + } + AppStateRow updatedAppState = new AppStateRow(); + updatedAppState.setAppName(appName); + updatedAppState.setState(toJson(appState)); + updatedAppState.setUpdateTime(Instant.now()); + stateDao.upsertAppState(ops, updatedAppState); + } + + if (!userStateDelta.isEmpty()) { + for (Map.Entry entry : userStateDelta.entrySet()) { + if (entry.getValue() == State.REMOVED) { + userState.remove(entry.getKey()); + } else { + userState.put(entry.getKey(), entry.getValue()); + } + } + UserStateRow updatedUserState = new UserStateRow(); + updatedUserState.setAppName(appName); + updatedUserState.setUserId(userId); + updatedUserState.setState(toJson(userState)); + updatedUserState.setUpdateTime(Instant.now()); + stateDao.upsertUserState(ops, updatedUserState); + } + + if (!sessionStateDelta.isEmpty()) { + Map sessionState = fromJson(sessionRow.getState()); + for (Map.Entry entry : sessionStateDelta.entrySet()) { + if (entry.getValue() == State.REMOVED) { + sessionState.remove(entry.getKey()); + } else { + sessionState.put(entry.getKey(), entry.getValue()); + } + } + sessionRow.setState(toJson(sessionState)); + } + } + + sessionRow.setUpdateTime(Instant.now()); + sessionDao.updateSession(ops, sessionRow); + + EventRow eventRow = fromDomainEvent(event, appName, userId, sessionId); + eventDao.insertEvent(ops, eventRow); + + return null; + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Retrieves the app-level state for an application. + * + *

App-level state is shared across all users and sessions for the specified application. This + * is typically used for application-wide configuration or data. + * + * @param appName the application name (must not be null) + * @return a Single that emits the app state map, or null if no app state exists + */ + public Single> getAppState(String appName) { + return Single.fromCallable( + () -> { + return jdbcTemplate.inTransaction( + ops -> { + Optional appStateOpt = stateDao.getAppState(ops, appName); + return appStateOpt + .map(s -> (Map) fromJson(s.getState())) + .orElse(null); + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Sets or replaces the app-level state for an application. + * + *

This operation completely replaces the existing app state. If you need to update specific + * keys, retrieve the current state first, modify it, and then set it back. + * + *

Warning: This affects all users and sessions for the application. + * + * @param appName the application name (must not be null) + * @param state the new app state map (must not be null) + * @return a Completable that completes when the state is updated + */ + public Completable setAppState(String appName, Map state) { + return Completable.fromAction( + () -> { + jdbcTemplate.inTransaction( + ops -> { + AppStateRow row = new AppStateRow(); + row.setAppName(appName); + row.setState(toJson(state)); + row.setUpdateTime(Instant.now()); + + stateDao.upsertAppState(ops, row); + return null; + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Retrieves the user-level state for a specific user in an application. + * + *

User-level state is shared across all sessions for the specified user. This is typically + * used for user preferences or data that should persist across sessions. + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @return a Single that emits the user state map, or null if no user state exists + */ + public Single> getUserState(String appName, String userId) { + return Single.fromCallable( + () -> { + return jdbcTemplate.inTransaction( + ops -> { + Optional userStateOpt = + stateDao.getUserState(ops, appName, userId); + return userStateOpt + .map(s -> (Map) fromJson(s.getState())) + .orElse(null); + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Sets or replaces the user-level state for a specific user in an application. + * + *

This operation completely replaces the existing user state. If you need to update specific + * keys, retrieve the current state first, modify it, and then set it back. + * + *

Warning: This affects all sessions for the specified user. + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param state the new user state map (must not be null) + * @return a Completable that completes when the state is updated + */ + public Completable setUserState(String appName, String userId, Map state) { + return Completable.fromAction( + () -> { + jdbcTemplate.inTransaction( + ops -> { + UserStateRow row = new UserStateRow(); + row.setAppName(appName); + row.setUserId(userId); + row.setState(toJson(state)); + row.setUpdateTime(Instant.now()); + + stateDao.upsertUserState(ops, row); + return null; + }); + }) + .subscribeOn(Schedulers.io()); + } + + private Session toDomainSession( + SessionRow row, + List eventRows, + Map appState, + Map userState) { + ConcurrentMap mergedState = new ConcurrentHashMap<>(); + + if (appState != null) { + for (Map.Entry entry : appState.entrySet()) { + mergedState.put(State.APP_PREFIX + entry.getKey(), entry.getValue()); + } + } + + if (userState != null) { + for (Map.Entry entry : userState.entrySet()) { + mergedState.put(State.USER_PREFIX + entry.getKey(), entry.getValue()); + } + } + + Map sessionStateMap = fromJson(row.getState()); + if (sessionStateMap != null) { + mergedState.putAll(sessionStateMap); + } + + List events = eventRows.stream().map(this::toEvent).collect(Collectors.toList()); + + return Session.builder(row.getId()) + .appName(row.getAppName()) + .userId(row.getUserId()) + .state(mergedState) + .events(events) + .lastUpdateTime(row.getUpdateTime()) + .build(); + } + + private Event toEvent(EventRow row) { + try { + Event event = objectMapper.readValue(row.getEventData(), Event.class); + + event.setId(row.getId()); + event.setInvocationId(row.getInvocationId()); + event.setTimestamp(row.getTimestamp().toEpochMilli()); + + return event; + } catch (Exception e) { + logger.error("Failed to deserialize event {}: {}", row.getId(), e.getMessage(), e); + throw new SessionException("Failed to convert EventRow to Event", e); + } + } + + private EventRow fromDomainEvent(Event event, String appName, String userId, String sessionId) { + EventRow row = new EventRow(); + row.setId(event.id()); + row.setAppName(appName); + row.setUserId(userId); + row.setSessionId(sessionId); + row.setInvocationId(event.invocationId()); + row.setTimestamp(Instant.ofEpochMilli(event.timestamp())); + + try { + Map eventDataMap = + objectMapper.convertValue(event, new TypeReference>() {}); + + eventDataMap.remove("id"); + eventDataMap.remove("invocationId"); + eventDataMap.remove("timestamp"); + + String eventDataJson = objectMapper.writeValueAsString(eventDataMap); + row.setEventData(eventDataJson); + } catch (Exception e) { + logger.error("Failed to serialize event {}: {}", event.id(), e.getMessage(), e); + throw new SessionException("Failed to convert Event to EventRow", e); + } + + return row; + } + + private String toJson(Map map) { + try { + return objectMapper.writeValueAsString(map); + } catch (Exception e) { + throw new SessionException("Failed to serialize to JSON", e); + } + } + + private Map fromJson(String json) { + if (json == null || json.isEmpty()) { + return new ConcurrentHashMap<>(); + } + try { + return objectMapper.readValue(json, new TypeReference>() {}); + } catch (Exception e) { + throw new SessionException("Failed to deserialize from JSON", e); + } + } + + private void checkNotClosed() { + if (closed.get()) { + throw new IllegalStateException( + "DatabaseSessionService is closed. Create a new instance or ensure close() is not called prematurely."); + } + } + + private static int getIntProperty(Map properties, String key, int defaultValue) { + Object value = properties.get(key); + if (value == null) { + return defaultValue; + } + if (value instanceof Integer) { + return (Integer) value; + } + if (value instanceof Number) { + return ((Number) value).intValue(); + } + if (value instanceof String) { + try { + return Integer.parseInt((String) value); + } catch (NumberFormatException e) { + logger.warn( + "Invalid integer value for property {}: {}. Using default: {}", + key, + value, + defaultValue); + return defaultValue; + } + } + logger.warn( + "Unsupported type for property {}: {}. Using default: {}", + key, + value.getClass().getName(), + defaultValue); + return defaultValue; + } + + private static long getLongProperty( + Map properties, String key, long defaultValue) { + Object value = properties.get(key); + if (value == null) { + return defaultValue; + } + if (value instanceof Long) { + return (Long) value; + } + if (value instanceof Number) { + return ((Number) value).longValue(); + } + if (value instanceof String) { + try { + return Long.parseLong((String) value); + } catch (NumberFormatException e) { + logger.warn( + "Invalid long value for property {}: {}. Using default: {}", key, value, defaultValue); + return defaultValue; + } + } + logger.warn( + "Unsupported type for property {}: {}. Using default: {}", + key, + value.getClass().getName(), + defaultValue); + return defaultValue; + } + + /** + * Helper method to create a session within a transaction. + * + *

Extracted from {@link #createSession} to improve testability and reduce lambda nesting. + * + * @param appName the application name + * @param userId the user identifier + * @param state initial state map (may be null) + * @param sessionId session ID (generates UUID if null/empty) + * @return the created Session + */ + private Session createSessionInTransaction( + String appName, String userId, ConcurrentMap state, String sessionId) + throws java.sql.SQLException { + String id = + (sessionId != null && !sessionId.isEmpty()) ? sessionId : UUID.randomUUID().toString(); + + Instant now = Instant.now(); + + return jdbcTemplate.inTransaction( + ops -> { + Map appStateMap = new ConcurrentHashMap<>(); + Map userStateMap = new ConcurrentHashMap<>(); + Map sessionStateMap = new ConcurrentHashMap<>(); + + if (state != null) { + for (Map.Entry entry : state.entrySet()) { + String key = entry.getKey(); + if (key.startsWith(State.TEMP_PREFIX)) { + continue; + } + + if (key.startsWith(State.APP_PREFIX)) { + String unprefixedKey = key.substring(State.APP_PREFIX.length()); + appStateMap.put(unprefixedKey, entry.getValue()); + } else if (key.startsWith(State.USER_PREFIX)) { + String unprefixedKey = key.substring(State.USER_PREFIX.length()); + userStateMap.put(unprefixedKey, entry.getValue()); + } else { + sessionStateMap.put(key, entry.getValue()); + } + } + } + + Map appState = upsertAppStateIfNeeded(ops, appName, appStateMap, now); + Map userState = + upsertUserStateIfNeeded(ops, appName, userId, userStateMap, now); + + SessionRow row = new SessionRow(); + row.setAppName(appName); + row.setUserId(userId); + row.setId(id); + row.setState(toJson(sessionStateMap)); + row.setCreateTime(now); + row.setUpdateTime(now); + + sessionDao.insertSession(ops, row); + + return toDomainSession(row, Collections.emptyList(), appState, userState); + }); + } + + /** + * Helper method to build a Session from a SessionRow within a transaction. + * + *

Extracted from {@link #getSession} to improve testability and reduce lambda nesting. + * + * @param ops transaction operations + * @param sessionRow the session row from database + * @param config optional configuration for event filtering + * @return the built Session + */ + private Session buildSessionFromRow( + JdbcTemplate.JdbcOperations ops, SessionRow sessionRow, Optional config) + throws java.sql.SQLException { + String appName = sessionRow.getAppName(); + String userId = sessionRow.getUserId(); + String sessionId = sessionRow.getId(); + + // Convert negative values to positive: -N becomes N (last N events) + Optional limit = config.flatMap(c -> c.numRecentEvents().map(Math::abs)); + + List eventRows; + if (config.isPresent() && config.get().afterTimestamp().isPresent()) { + Instant afterTimestamp = config.get().afterTimestamp().get(); + eventRows = + eventDao.listEventsAfterTimestamp( + ops, appName, userId, sessionId, afterTimestamp, limit, 0); + } else { + eventRows = eventDao.listEvents(ops, appName, userId, sessionId, limit); + } + + Optional appStateOpt = stateDao.getAppState(ops, appName); + Map appState = + appStateOpt.map(s -> fromJson(s.getState())).orElse(new ConcurrentHashMap<>()); + + Optional userStateOpt = stateDao.getUserState(ops, appName, userId); + Map userState = + userStateOpt.map(s -> fromJson(s.getState())).orElse(new ConcurrentHashMap<>()); + + return toDomainSession(sessionRow, eventRows, appState, userState); + } + + /** + * Helper method to upsert app state if needed. + * + * @param ops transaction operations + * @param appName application name + * @param appStateMap state map to upsert (may be empty) + * @param now current timestamp + * @return the merged app state map + */ + private Map upsertAppStateIfNeeded( + JdbcTemplate.JdbcOperations ops, String appName, Map appStateMap, Instant now) + throws java.sql.SQLException { + Optional appStateOpt = stateDao.getAppStateForUpdate(ops, appName); + Map appState; + + if (appStateOpt.isPresent()) { + appState = fromJson(appStateOpt.get().getState()); + if (!appStateMap.isEmpty()) { + appState.putAll(appStateMap); + AppStateRow updatedAppState = new AppStateRow(); + updatedAppState.setAppName(appName); + updatedAppState.setState(toJson(appState)); + updatedAppState.setUpdateTime(now); + stateDao.upsertAppState(ops, updatedAppState); + } + } else if (!appStateMap.isEmpty()) { + appState = new ConcurrentHashMap<>(appStateMap); + AppStateRow newAppState = new AppStateRow(); + newAppState.setAppName(appName); + newAppState.setState(toJson(appState)); + newAppState.setUpdateTime(now); + stateDao.upsertAppState(ops, newAppState); + } else { + appState = new ConcurrentHashMap<>(); + } + + return appState; + } + + /** + * Helper method to upsert user state if needed. + * + * @param ops transaction operations + * @param appName application name + * @param userId user identifier + * @param userStateMap state map to upsert (may be empty) + * @param now current timestamp + * @return the merged user state map + */ + private Map upsertUserStateIfNeeded( + JdbcTemplate.JdbcOperations ops, + String appName, + String userId, + Map userStateMap, + Instant now) + throws java.sql.SQLException { + Optional userStateOpt = stateDao.getUserStateForUpdate(ops, appName, userId); + Map userState; + + if (userStateOpt.isPresent()) { + userState = fromJson(userStateOpt.get().getState()); + if (!userStateMap.isEmpty()) { + userState.putAll(userStateMap); + UserStateRow updatedUserState = new UserStateRow(); + updatedUserState.setAppName(appName); + updatedUserState.setUserId(userId); + updatedUserState.setState(toJson(userState)); + updatedUserState.setUpdateTime(now); + stateDao.upsertUserState(ops, updatedUserState); + } + } else if (!userStateMap.isEmpty()) { + userState = new ConcurrentHashMap<>(userStateMap); + UserStateRow newUserState = new UserStateRow(); + newUserState.setAppName(appName); + newUserState.setUserId(userId); + newUserState.setState(toJson(userState)); + newUserState.setUpdateTime(now); + stateDao.upsertUserState(ops, newUserState); + } else { + userState = new ConcurrentHashMap<>(); + } + + return userState; + } + + /** + * Closes this service and releases all database connections. + * + *

This method shuts down the HikariCP connection pool and releases all associated resources. + * After calling this method, the service cannot be used again - create a new instance if needed. + * + *

This method is idempotent - calling it multiple times has no additional effect. + * + *

Thread Safety: This method uses atomic compare-and-set to ensure the connection pool + * is closed exactly once, even if called concurrently from multiple threads. + * + *

Best Practice: Use try-with-resources to ensure automatic cleanup: + * + *

{@code
+   * try (DatabaseSessionService service = new DatabaseSessionService(jdbcUrl)) {
+   *     // Use the service
+   * } // Automatically closed
+   * }
+ * + * @throws IllegalStateException if any operations are attempted after closing + */ + @Override + public void close() { + if (closed.compareAndSet(false, true)) { + logger.info("Closing DatabaseSessionService"); + if (dataSource != null && !dataSource.isClosed()) { + dataSource.close(); + logger.info("HikariCP connection pool closed"); + } + } + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/EventDao.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/EventDao.java new file mode 100644 index 000000000..79f49f5c4 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/EventDao.java @@ -0,0 +1,141 @@ +package com.google.adk.sessions.dao; + +import com.google.adk.sessions.dialect.SqlDialect; +import com.google.adk.sessions.model.EventRow; +import com.google.adk.sessions.util.JdbcTemplate.JdbcOperations; +import com.google.adk.sessions.util.RowMapper; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class EventDao { + + private static final Logger logger = LoggerFactory.getLogger(EventDao.class); + private final SqlDialect dialect; + + public EventDao(SqlDialect dialect) { + this.dialect = dialect; + logger.debug("EventDao initialized with {} dialect", dialect.dialectName()); + } + + private static final RowMapper ROW_MAPPER = + rs -> { + EventRow row = new EventRow(); + row.setId(rs.getString("id")); + row.setAppName(rs.getString("app_name")); + row.setUserId(rs.getString("user_id")); + row.setSessionId(rs.getString("session_id")); + row.setInvocationId(rs.getString("invocation_id")); + row.setEventData(rs.getString("event_data")); + + Timestamp ts = rs.getTimestamp("timestamp"); + row.setTimestamp(ts != null ? ts.toInstant() : null); + + return row; + }; + + public List listEvents( + JdbcOperations ops, String appName, String userId, String sessionId) throws SQLException { + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("sessionId", sessionId); + + String sql = + "SELECT * FROM events " + + "WHERE app_name = :appName AND user_id = :userId AND session_id = :sessionId " + + "ORDER BY timestamp ASC"; + + return ops.query(sql, params, ROW_MAPPER); + } + + public List listEvents( + JdbcOperations ops, String appName, String userId, String sessionId, Optional limit) + throws SQLException { + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("sessionId", sessionId); + + String sql = + "SELECT * FROM events " + + "WHERE app_name = :appName AND user_id = :userId AND session_id = :sessionId " + + "ORDER BY timestamp DESC"; + + if (limit.isPresent()) { + sql += " LIMIT :limit"; + params.put("limit", limit.get()); + } + + List events = ops.query(sql, params, ROW_MAPPER); + Collections.reverse(events); + return events; + } + + public List listEventsAfterTimestamp( + JdbcOperations ops, + String appName, + String userId, + String sessionId, + java.time.Instant afterTimestamp, + Optional limit, + int offset) + throws SQLException { + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("sessionId", sessionId); + params.put("afterTimestamp", java.sql.Timestamp.from(afterTimestamp)); + + String sql = + "SELECT * FROM events " + + "WHERE app_name = :appName AND user_id = :userId AND session_id = :sessionId " + + "AND timestamp > :afterTimestamp " + + "ORDER BY timestamp ASC"; + + if (limit.isPresent()) { + sql += " LIMIT :limit OFFSET :offset"; + params.put("limit", limit.get()); + params.put("offset", offset); + } + + return ops.query(sql, params, ROW_MAPPER); + } + + public void insertEvent(JdbcOperations ops, EventRow event) throws SQLException { + String sql = + "INSERT INTO events (id, app_name, user_id, session_id, invocation_id, timestamp, event_data) " + + "VALUES (:id, :appName, :userId, :sessionId, :invocationId, :timestamp, " + + dialect.jsonValue(":eventData") + + ")"; + + Map params = new HashMap<>(); + params.put("id", event.getId()); + params.put("appName", event.getAppName()); + params.put("userId", event.getUserId()); + params.put("sessionId", event.getSessionId()); + params.put("invocationId", event.getInvocationId()); + params.put("timestamp", Timestamp.from(event.getTimestamp())); + params.put("eventData", event.getEventData()); + + logger.debug("Appending event: eventId={}, sessionId={}", event.getId(), event.getSessionId()); + + ops.update(sql, params); + } + + public long countEvents(JdbcOperations ops, String appName, String userId, String sessionId) + throws SQLException { + String sql = + "SELECT COUNT(*) as count FROM events " + + "WHERE app_name = :appName AND user_id = :userId AND session_id = :sessionId"; + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("sessionId", sessionId); + + return ops.queryForObject(sql, params, rs -> rs.getLong("count")).orElse(0L); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/SessionDao.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/SessionDao.java new file mode 100644 index 000000000..71e664390 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/SessionDao.java @@ -0,0 +1,131 @@ +package com.google.adk.sessions.dao; + +import com.google.adk.sessions.dialect.SqlDialect; +import com.google.adk.sessions.model.SessionRow; +import com.google.adk.sessions.util.JdbcTemplate.JdbcOperations; +import com.google.adk.sessions.util.RowMapper; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SessionDao { + + private static final Logger logger = LoggerFactory.getLogger(SessionDao.class); + private final SqlDialect dialect; + + public SessionDao(SqlDialect dialect) { + this.dialect = dialect; + logger.debug("SessionDao initialized with {} dialect", dialect.dialectName()); + } + + private static final RowMapper ROW_MAPPER = + rs -> { + SessionRow row = new SessionRow(); + row.setAppName(rs.getString("app_name")); + row.setUserId(rs.getString("user_id")); + row.setId(rs.getString("id")); + row.setState(rs.getString("state")); + + Timestamp createTs = rs.getTimestamp("create_time"); + row.setCreateTime(createTs != null ? createTs.toInstant() : null); + + Timestamp updateTs = rs.getTimestamp("update_time"); + row.setUpdateTime(updateTs != null ? updateTs.toInstant() : null); + + return row; + }; + + public Optional findSession( + JdbcOperations ops, String appName, String userId, String id) throws SQLException { + String sql = + "SELECT * FROM sessions " + "WHERE app_name = :appName AND user_id = :userId AND id = :id"; + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("id", id); + + return ops.queryForObject(sql, params, ROW_MAPPER); + } + + public Optional findSessionForUpdate( + JdbcOperations ops, String appName, String userId, String id) throws SQLException { + String sql = + "SELECT * FROM sessions " + + "WHERE app_name = :appName AND user_id = :userId AND id = :id " + + dialect.forUpdateSyntax(); + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("id", id); + + return ops.queryForObject(sql, params, ROW_MAPPER); + } + + public List listSessions(JdbcOperations ops, String appName, String userId) + throws SQLException { + String sql = + "SELECT * FROM sessions " + + "WHERE app_name = :appName AND user_id = :userId " + + "ORDER BY update_time DESC"; + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + + return ops.query(sql, params, ROW_MAPPER); + } + + public void insertSession(JdbcOperations ops, SessionRow session) throws SQLException { + String sql = + "INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time) " + + "VALUES (:appName, :userId, :id, " + + dialect.jsonValue(":state") + + ", :createTime, :updateTime)"; + + Map params = new HashMap<>(); + params.put("appName", session.getAppName()); + params.put("userId", session.getUserId()); + params.put("id", session.getId()); + params.put("state", session.getState()); + params.put("createTime", Timestamp.from(session.getCreateTime())); + params.put("updateTime", Timestamp.from(session.getUpdateTime())); + + logger.debug( + "Inserting session: app={}, user={}, sessionId={}", + session.getAppName(), + session.getUserId(), + session.getId()); + ops.update(sql, params); + logger.debug("Session created successfully: {}", session.getId()); + } + + public void updateSession(JdbcOperations ops, SessionRow session) throws SQLException { + String sql = + "UPDATE sessions " + + "SET state = " + + dialect.jsonValue(":state") + + ", update_time = :updateTime " + + "WHERE app_name = :appName AND user_id = :userId AND id = :id"; + + Map params = new HashMap<>(); + params.put("state", session.getState()); + params.put("updateTime", Timestamp.from(session.getUpdateTime())); + params.put("appName", session.getAppName()); + params.put("userId", session.getUserId()); + params.put("id", session.getId()); + + ops.update(sql, params); + } + + public void deleteSession(JdbcOperations ops, String appName, String userId, String id) + throws SQLException { + dialect.deleteSession(ops, appName, userId, id); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/StateDao.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/StateDao.java new file mode 100644 index 000000000..1a8b672f9 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/StateDao.java @@ -0,0 +1,110 @@ +package com.google.adk.sessions.dao; + +import com.google.adk.sessions.dialect.SqlDialect; +import com.google.adk.sessions.model.AppStateRow; +import com.google.adk.sessions.model.UserStateRow; +import com.google.adk.sessions.util.JdbcTemplate.JdbcOperations; +import com.google.adk.sessions.util.RowMapper; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class StateDao { + + private static final Logger logger = LoggerFactory.getLogger(StateDao.class); + private final SqlDialect dialect; + + public StateDao(SqlDialect dialect) { + this.dialect = dialect; + } + + private static final RowMapper APP_STATE_MAPPER = + rs -> { + AppStateRow row = new AppStateRow(); + row.setAppName(rs.getString("app_name")); + row.setState(rs.getString("state")); + + Timestamp updateTs = rs.getTimestamp("update_time"); + row.setUpdateTime(updateTs != null ? updateTs.toInstant() : null); + + return row; + }; + + private static final RowMapper USER_STATE_MAPPER = + rs -> { + UserStateRow row = new UserStateRow(); + row.setAppName(rs.getString("app_name")); + row.setUserId(rs.getString("user_id")); + row.setState(rs.getString("state")); + + Timestamp updateTs = rs.getTimestamp("update_time"); + row.setUpdateTime(updateTs != null ? updateTs.toInstant() : null); + + return row; + }; + + public Optional getAppState(JdbcOperations ops, String appName) throws SQLException { + String sql = "SELECT * FROM app_states WHERE app_name = :appName"; + + Map params = new HashMap<>(); + params.put("appName", appName); + + return ops.queryForObject(sql, params, APP_STATE_MAPPER); + } + + public Optional getAppStateForUpdate(JdbcOperations ops, String appName) + throws SQLException { + String sql = "SELECT * FROM app_states WHERE app_name = :appName " + dialect.forUpdateSyntax(); + + Map params = new HashMap<>(); + params.put("appName", appName); + + return ops.queryForObject(sql, params, APP_STATE_MAPPER); + } + + public void upsertAppState(JdbcOperations ops, AppStateRow appState) throws SQLException { + logger.debug("Upserting app state for app: {}", appState.getAppName()); + dialect.upsertAppState(ops, appState); + logger.debug("App state upserted successfully for app: {}", appState.getAppName()); + } + + public Optional getUserState(JdbcOperations ops, String appName, String userId) + throws SQLException { + String sql = "SELECT * FROM user_states WHERE app_name = :appName AND user_id = :userId"; + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + + return ops.queryForObject(sql, params, USER_STATE_MAPPER); + } + + public Optional getUserStateForUpdate( + JdbcOperations ops, String appName, String userId) throws SQLException { + String sql = + "SELECT * FROM user_states WHERE app_name = :appName AND user_id = :userId " + + dialect.forUpdateSyntax(); + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + + return ops.queryForObject(sql, params, USER_STATE_MAPPER); + } + + public void upsertUserState(JdbcOperations ops, UserStateRow userState) throws SQLException { + logger.debug( + "Upserting user state for app: {}, user: {}", + userState.getAppName(), + userState.getUserId()); + dialect.upsertUserState(ops, userState); + logger.debug( + "User state upserted successfully for app: {}, user: {}", + userState.getAppName(), + userState.getUserId()); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/DialectDetector.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/DialectDetector.java new file mode 100644 index 000000000..0aa7afe77 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/DialectDetector.java @@ -0,0 +1,45 @@ +package com.google.adk.sessions.dialect; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.SQLException; + +public class DialectDetector { + + public static SqlDialect detect(Connection connection) throws SQLException { + DatabaseMetaData metaData = connection.getMetaData(); + String productName = metaData.getDatabaseProductName().toLowerCase(); + + if (productName.contains("postgresql")) { + return new PostgresDialect(); + } else if (productName.contains("mysql")) { + return new MySqlDialect(); + } else if (productName.contains("h2")) { + return new H2Dialect(); + } else if (productName.contains("spanner")) { + return new SpannerDialect(); + } else { + throw new IllegalArgumentException( + "Unsupported database: " + + productName + + ". " + + "Supported databases: PostgreSQL, MySQL, H2, Cloud Spanner"); + } + } + + public static SqlDialect detectFromJdbcUrl(String jdbcUrl) { + String url = jdbcUrl.toLowerCase(); + + if (url.startsWith("jdbc:postgresql:")) { + return new PostgresDialect(); + } else if (url.startsWith("jdbc:mysql:")) { + return new MySqlDialect(); + } else if (url.startsWith("jdbc:h2:")) { + return new H2Dialect(); + } else if (url.startsWith("jdbc:cloudspanner:")) { + return new SpannerDialect(); + } else { + throw new IllegalArgumentException("Cannot detect dialect from JDBC URL: " + jdbcUrl); + } + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/H2Dialect.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/H2Dialect.java new file mode 100644 index 000000000..d75a66dc7 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/H2Dialect.java @@ -0,0 +1,33 @@ +package com.google.adk.sessions.dialect; + +public class H2Dialect implements SqlDialect { + + @Override + public String dialectName() { + return "H2"; + } + + @Override + public String jsonCastSyntax() { + return ""; + } + + @Override + public String forUpdateSyntax() { + return "FOR UPDATE"; + } + + @Override + public String upsertAppStateSql() { + return "MERGE INTO app_states (app_name, state, update_time) " + + "KEY (app_name) " + + "VALUES (:appName, :state, :updateTime)"; + } + + @Override + public String upsertUserStateSql() { + return "MERGE INTO user_states (app_name, user_id, state, update_time) " + + "KEY (app_name, user_id) " + + "VALUES (:appName, :userId, :state, :updateTime)"; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/MySqlDialect.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/MySqlDialect.java new file mode 100644 index 000000000..b30e37d78 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/MySqlDialect.java @@ -0,0 +1,33 @@ +package com.google.adk.sessions.dialect; + +public class MySqlDialect implements SqlDialect { + + @Override + public String dialectName() { + return "MySQL"; + } + + @Override + public String jsonCastSyntax() { + return ""; + } + + @Override + public String forUpdateSyntax() { + return "FOR UPDATE"; + } + + @Override + public String upsertAppStateSql() { + return "INSERT INTO app_states (app_name, state, update_time) " + + "VALUES (:appName, :state, :updateTime) " + + "ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = VALUES(update_time)"; + } + + @Override + public String upsertUserStateSql() { + return "INSERT INTO user_states (app_name, user_id, state, update_time) " + + "VALUES (:appName, :userId, :state, :updateTime) " + + "ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = VALUES(update_time)"; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/PostgresDialect.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/PostgresDialect.java new file mode 100644 index 000000000..c69c8f1b1 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/PostgresDialect.java @@ -0,0 +1,35 @@ +package com.google.adk.sessions.dialect; + +public class PostgresDialect implements SqlDialect { + + @Override + public String dialectName() { + return "PostgreSQL"; + } + + @Override + public String jsonCastSyntax() { + return "::jsonb"; + } + + @Override + public String forUpdateSyntax() { + return "FOR UPDATE"; + } + + @Override + public String upsertAppStateSql() { + return "INSERT INTO app_states (app_name, state, update_time) " + + "VALUES (:appName, :state::jsonb, :updateTime) " + + "ON CONFLICT (app_name) " + + "DO UPDATE SET state = EXCLUDED.state, update_time = EXCLUDED.update_time"; + } + + @Override + public String upsertUserStateSql() { + return "INSERT INTO user_states (app_name, user_id, state, update_time) " + + "VALUES (:appName, :userId, :state::jsonb, :updateTime) " + + "ON CONFLICT (app_name, user_id) " + + "DO UPDATE SET state = EXCLUDED.state, update_time = EXCLUDED.update_time"; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SpannerDialect.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SpannerDialect.java new file mode 100644 index 000000000..b163a0013 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SpannerDialect.java @@ -0,0 +1,65 @@ +package com.google.adk.sessions.dialect; + +import com.google.adk.sessions.util.JdbcTemplate.JdbcOperations; +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Map; + +public class SpannerDialect implements SqlDialect { + + @Override + public String dialectName() { + return "Cloud Spanner"; + } + + @Override + public String jsonCastSyntax() { + return ""; + } + + @Override + public String jsonValue(String paramName) { + return "PARSE_JSON(" + paramName + ")"; + } + + @Override + public String forUpdateSyntax() { + return ""; + } + + @Override + public String upsertAppStateSql() { + return "INSERT OR UPDATE app_states (app_name, state, update_time) " + + "VALUES (:appName, " + + jsonValue(":state") + + ", :updateTime)"; + } + + @Override + public String upsertUserStateSql() { + return "INSERT OR UPDATE user_states (app_name, user_id, state, update_time) " + + "VALUES (:appName, :userId, " + + jsonValue(":state") + + ", :updateTime)"; + } + + @Override + public void deleteSession(JdbcOperations ops, String appName, String userId, String sessionId) + throws SQLException { + String deleteEventsSql = + "DELETE FROM events WHERE app_name = :appName AND user_id = :userId AND session_id = :sessionId"; + Map eventsParams = new HashMap<>(); + eventsParams.put("appName", appName); + eventsParams.put("userId", userId); + eventsParams.put("sessionId", sessionId); + ops.update(deleteEventsSql, eventsParams); + + String deleteSessionSql = + "DELETE FROM sessions WHERE app_name = :appName AND user_id = :userId AND id = :id"; + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("id", sessionId); + ops.update(deleteSessionSql, params); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SqlDialect.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SqlDialect.java new file mode 100644 index 000000000..b9089033e --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SqlDialect.java @@ -0,0 +1,56 @@ +package com.google.adk.sessions.dialect; + +import com.google.adk.sessions.model.AppStateRow; +import com.google.adk.sessions.model.UserStateRow; +import com.google.adk.sessions.util.JdbcTemplate.JdbcOperations; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.HashMap; +import java.util.Map; + +public interface SqlDialect { + + String dialectName(); + + String jsonCastSyntax(); + + String forUpdateSyntax(); + + String upsertAppStateSql(); + + String upsertUserStateSql(); + + default String jsonValue(String paramName) { + return paramName + jsonCastSyntax(); + } + + default void upsertAppState(JdbcOperations ops, AppStateRow appState) throws SQLException { + String sql = upsertAppStateSql(); + Map params = new HashMap<>(); + params.put("appName", appState.getAppName()); + params.put("state", appState.getState()); + params.put("updateTime", Timestamp.from(appState.getUpdateTime())); + ops.update(sql, params); + } + + default void upsertUserState(JdbcOperations ops, UserStateRow userState) throws SQLException { + String sql = upsertUserStateSql(); + Map params = new HashMap<>(); + params.put("appName", userState.getAppName()); + params.put("userId", userState.getUserId()); + params.put("state", userState.getState()); + params.put("updateTime", Timestamp.from(userState.getUpdateTime())); + ops.update(sql, params); + } + + default void deleteSession(JdbcOperations ops, String appName, String userId, String sessionId) + throws SQLException { + String sql = + "DELETE FROM sessions WHERE app_name = :appName AND user_id = :userId AND id = :id"; + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("id", sessionId); + ops.update(sql, params); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/AppStateRow.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/AppStateRow.java new file mode 100644 index 000000000..8060a0416 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/AppStateRow.java @@ -0,0 +1,33 @@ +package com.google.adk.sessions.model; + +import java.time.Instant; + +public class AppStateRow { + private String appName; + private String state; + private Instant updateTime; + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public Instant getUpdateTime() { + return updateTime; + } + + public void setUpdateTime(Instant updateTime) { + this.updateTime = updateTime; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/EventRow.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/EventRow.java new file mode 100644 index 000000000..607c5df9d --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/EventRow.java @@ -0,0 +1,69 @@ +package com.google.adk.sessions.model; + +import java.time.Instant; + +public class EventRow { + private String id; + private String appName; + private String userId; + private String sessionId; + private String invocationId; + private Instant timestamp; + private String eventData; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getSessionId() { + return sessionId; + } + + public void setSessionId(String sessionId) { + this.sessionId = sessionId; + } + + public String getInvocationId() { + return invocationId; + } + + public void setInvocationId(String invocationId) { + this.invocationId = invocationId; + } + + public Instant getTimestamp() { + return timestamp; + } + + public void setTimestamp(Instant timestamp) { + this.timestamp = timestamp; + } + + public String getEventData() { + return eventData; + } + + public void setEventData(String eventData) { + this.eventData = eventData; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/SessionRow.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/SessionRow.java new file mode 100644 index 000000000..0dcf1a97e --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/SessionRow.java @@ -0,0 +1,60 @@ +package com.google.adk.sessions.model; + +import java.time.Instant; + +public class SessionRow { + private String appName; + private String userId; + private String id; + private String state; + private Instant createTime; + private Instant updateTime; + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public Instant getCreateTime() { + return createTime; + } + + public void setCreateTime(Instant createTime) { + this.createTime = createTime; + } + + public Instant getUpdateTime() { + return updateTime; + } + + public void setUpdateTime(Instant updateTime) { + this.updateTime = updateTime; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/UserStateRow.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/UserStateRow.java new file mode 100644 index 000000000..bd02f0ddb --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/UserStateRow.java @@ -0,0 +1,42 @@ +package com.google.adk.sessions.model; + +import java.time.Instant; + +public class UserStateRow { + private String appName; + private String userId; + private String state; + private Instant updateTime; + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public Instant getUpdateTime() { + return updateTime; + } + + public void setUpdateTime(Instant updateTime) { + this.updateTime = updateTime; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/JdbcTemplate.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/JdbcTemplate.java new file mode 100644 index 000000000..b0aa3ee66 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/JdbcTemplate.java @@ -0,0 +1,101 @@ +package com.google.adk.sessions.util; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import javax.sql.DataSource; + +public class JdbcTemplate { + + private final DataSource dataSource; + + public JdbcTemplate(DataSource dataSource) { + this.dataSource = dataSource; + } + + public T inTransaction(TransactionCallback callback) throws SQLException { + try (Connection conn = dataSource.getConnection()) { + boolean originalAutoCommit = conn.getAutoCommit(); + try { + conn.setAutoCommit(false); + T result = callback.doInTransaction(new JdbcOperations(conn)); + conn.commit(); + return result; + } catch (Exception e) { + conn.rollback(); + throw e; + } finally { + conn.setAutoCommit(originalAutoCommit); + } + } + } + + @FunctionalInterface + public interface TransactionCallback { + T doInTransaction(JdbcOperations ops) throws SQLException; + } + + public static class JdbcOperations { + private final Connection connection; + + JdbcOperations(Connection connection) { + this.connection = connection; + } + + public Connection getConnection() { + return connection; + } + + public Optional queryForObject( + String sql, Map params, RowMapper mapper) throws SQLException { + NamedParameterSupport nps = NamedParameterSupport.parse(sql); + + try (PreparedStatement ps = connection.prepareStatement(nps.getParsedSql())) { + nps.setParameters(ps, params); + + try (ResultSet rs = ps.executeQuery()) { + if (rs.next()) { + return Optional.of(mapper.mapRow(rs)); + } + return Optional.empty(); + } + } + } + + public List query(String sql, Map params, RowMapper mapper) + throws SQLException { + NamedParameterSupport nps = NamedParameterSupport.parse(sql); + List results = new ArrayList<>(); + + try (PreparedStatement ps = connection.prepareStatement(nps.getParsedSql())) { + nps.setParameters(ps, params); + + try (ResultSet rs = ps.executeQuery()) { + while (rs.next()) { + results.add(mapper.mapRow(rs)); + } + } + } + + return results; + } + + public int update(String sql, Map params) throws SQLException { + NamedParameterSupport nps = NamedParameterSupport.parse(sql); + + try (PreparedStatement ps = connection.prepareStatement(nps.getParsedSql())) { + nps.setParameters(ps, params); + return ps.executeUpdate(); + } + } + + public int execute(String sql, Map params) throws SQLException { + return update(sql, params); + } + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/NamedParameterSupport.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/NamedParameterSupport.java new file mode 100644 index 000000000..4c4ab705d --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/NamedParameterSupport.java @@ -0,0 +1,58 @@ +package com.google.adk.sessions.util; + +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class NamedParameterSupport { + + private static final Pattern NAMED_PARAM_PATTERN = Pattern.compile("(? parameterNames; + + private NamedParameterSupport(String parsedSql, List parameterNames) { + this.parsedSql = parsedSql; + this.parameterNames = parameterNames; + } + + public static NamedParameterSupport parse(String namedSql) { + List parameterNames = new ArrayList<>(); + Matcher matcher = NAMED_PARAM_PATTERN.matcher(namedSql); + StringBuffer parsedSql = new StringBuffer(); + + while (matcher.find()) { + String paramName = matcher.group(1); + parameterNames.add(paramName); + matcher.appendReplacement(parsedSql, "?"); + } + matcher.appendTail(parsedSql); + + return new NamedParameterSupport(parsedSql.toString(), parameterNames); + } + + public String getParsedSql() { + return parsedSql; + } + + public void setParameters(PreparedStatement ps, Map params) throws SQLException { + for (int i = 0; i < parameterNames.size(); i++) { + String paramName = parameterNames.get(i); + + if (!params.containsKey(paramName)) { + throw new IllegalArgumentException("Missing parameter: " + paramName); + } + + Object value = params.get(paramName); + ps.setObject(i + 1, value); + } + } + + public List getParameterNames() { + return new ArrayList<>(parameterNames); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/RowMapper.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/RowMapper.java new file mode 100644 index 000000000..312c67d60 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/RowMapper.java @@ -0,0 +1,9 @@ +package com.google.adk.sessions.util; + +import java.sql.ResultSet; +import java.sql.SQLException; + +@FunctionalInterface +public interface RowMapper { + T mapRow(ResultSet rs) throws SQLException; +} diff --git a/contrib/database-session-service/src/main/resources/db/migration/h2/V1__Initial_schema.sql b/contrib/database-session-service/src/main/resources/db/migration/h2/V1__Initial_schema.sql new file mode 100644 index 000000000..bbb3bdca6 --- /dev/null +++ b/contrib/database-session-service/src/main/resources/db/migration/h2/V1__Initial_schema.sql @@ -0,0 +1,66 @@ +-- V1__Initial_schema.sql for H2 Database +-- Initial database schema for ADK DatabaseSessionService (v1 format) +-- This schema matches Python ADK v1 with simplified event storage using CLOB + +-- Create metadata table for schema versioning +CREATE TABLE IF NOT EXISTS adk_internal_metadata ( + "KEY" VARCHAR(128) PRIMARY KEY, + "VALUE" VARCHAR(256) +); + +-- Insert schema version (1 = v1 CLOB schema format, compatible with Python ADK) +MERGE INTO adk_internal_metadata ("KEY", "VALUE") VALUES ('schema_version', '1'); + +-- Create sessions table +CREATE TABLE IF NOT EXISTS sessions ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + id VARCHAR(128) NOT NULL, + state CLOB, + create_time TIMESTAMP(6), + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id, id) +); + +-- Create events table (v1 format with event_data CLOB column) +CREATE TABLE IF NOT EXISTS events ( + id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + timestamp TIMESTAMP(6), + event_data CLOB, + PRIMARY KEY (id, app_name, user_id, session_id), + FOREIGN KEY (app_name, user_id, session_id) + REFERENCES sessions(app_name, user_id, id) + ON DELETE CASCADE +); + +-- Create app states table +CREATE TABLE IF NOT EXISTS app_states ( + app_name VARCHAR(128) NOT NULL, + state CLOB, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name) +); + +-- Create user states table +CREATE TABLE IF NOT EXISTS user_states ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state CLOB, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) +); + +-- Add indexes to improve query performance + +-- Index for looking up sessions by app_name and user_id +CREATE INDEX IF NOT EXISTS idx_sessions_app_user ON sessions(app_name, user_id); + +-- Index for looking up events by session +CREATE INDEX IF NOT EXISTS idx_events_session ON events(app_name, user_id, session_id); + +-- Index for sorting events by timestamp +CREATE INDEX IF NOT EXISTS idx_events_timestamp ON events(timestamp); diff --git a/contrib/database-session-service/src/main/resources/db/migration/mysql/V1__Initial_schema.sql b/contrib/database-session-service/src/main/resources/db/migration/mysql/V1__Initial_schema.sql new file mode 100644 index 000000000..6d8184799 --- /dev/null +++ b/contrib/database-session-service/src/main/resources/db/migration/mysql/V1__Initial_schema.sql @@ -0,0 +1,66 @@ +-- V1__Initial_schema.sql for MySQL +-- Initial database schema for ADK DatabaseSessionService (v1 format) +-- This schema matches Python ADK v1 with simplified event storage using JSON + +-- Create metadata table for schema versioning +CREATE TABLE IF NOT EXISTS adk_internal_metadata ( + `key` VARCHAR(128) PRIMARY KEY, + value VARCHAR(256) +); + +-- Insert schema version (1 = v1 JSON schema format, compatible with Python ADK) +INSERT INTO adk_internal_metadata (`key`, value) VALUES ('schema_version', '1'); + +-- Create sessions table +CREATE TABLE IF NOT EXISTS sessions ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + id VARCHAR(128) NOT NULL, + state LONGTEXT, + create_time TIMESTAMP(6), + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id, id) +); + +-- Create events table (v1 format with event_data JSON column) +CREATE TABLE IF NOT EXISTS events ( + id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + timestamp TIMESTAMP(6), + event_data LONGTEXT, + PRIMARY KEY (id, app_name, user_id, session_id), + FOREIGN KEY (app_name, user_id, session_id) + REFERENCES sessions(app_name, user_id, id) + ON DELETE CASCADE +); + +-- Create app states table +CREATE TABLE IF NOT EXISTS app_states ( + app_name VARCHAR(128) NOT NULL, + state LONGTEXT, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name) +); + +-- Create user states table +CREATE TABLE IF NOT EXISTS user_states ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state LONGTEXT, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) +); + +-- Add indexes to improve query performance + +-- Index for looking up sessions by app_name and user_id +CREATE INDEX idx_sessions_app_user ON sessions(app_name, user_id); + +-- Index for looking up events by session +CREATE INDEX idx_events_session ON events(app_name, user_id, session_id); + +-- Index for sorting events by timestamp +CREATE INDEX idx_events_timestamp ON events(timestamp); diff --git a/contrib/database-session-service/src/main/resources/db/migration/postgresql/V1__Initial_schema.sql b/contrib/database-session-service/src/main/resources/db/migration/postgresql/V1__Initial_schema.sql new file mode 100644 index 000000000..810a2eb87 --- /dev/null +++ b/contrib/database-session-service/src/main/resources/db/migration/postgresql/V1__Initial_schema.sql @@ -0,0 +1,66 @@ +-- V1__Initial_schema.sql for PostgreSQL +-- Initial database schema for ADK DatabaseSessionService (v1 format) +-- This schema matches Python ADK v1 with simplified event storage using JSON + +-- Create metadata table for schema versioning +CREATE TABLE IF NOT EXISTS adk_internal_metadata ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(256) +); + +-- Insert schema version (1 = v1 JSON schema format, compatible with Python ADK) +INSERT INTO adk_internal_metadata (key, value) VALUES ('schema_version', '1'); + +-- Create sessions table +CREATE TABLE IF NOT EXISTS sessions ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + id VARCHAR(128) NOT NULL, + state JSONB, + create_time TIMESTAMP, + update_time TIMESTAMP, + PRIMARY KEY (app_name, user_id, id) +); + +-- Create events table (v1 format with event_data JSON column) +CREATE TABLE IF NOT EXISTS events ( + id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + timestamp TIMESTAMP, + event_data JSONB, + PRIMARY KEY (id, app_name, user_id, session_id), + FOREIGN KEY (app_name, user_id, session_id) + REFERENCES sessions(app_name, user_id, id) + ON DELETE CASCADE +); + +-- Create app states table +CREATE TABLE IF NOT EXISTS app_states ( + app_name VARCHAR(128) NOT NULL, + state JSONB, + update_time TIMESTAMP, + PRIMARY KEY (app_name) +); + +-- Create user states table +CREATE TABLE IF NOT EXISTS user_states ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB, + update_time TIMESTAMP, + PRIMARY KEY (app_name, user_id) +); + +-- Add indexes to improve query performance + +-- Index for looking up sessions by app_name and user_id +CREATE INDEX IF NOT EXISTS idx_sessions_app_user ON sessions(app_name, user_id); + +-- Index for looking up events by session +CREATE INDEX IF NOT EXISTS idx_events_session ON events(app_name, user_id, session_id); + +-- Index for sorting events by timestamp +CREATE INDEX IF NOT EXISTS idx_events_timestamp ON events(timestamp); diff --git a/contrib/database-session-service/src/main/resources/db/migration/spanner/V1__Initial_schema.sql b/contrib/database-session-service/src/main/resources/db/migration/spanner/V1__Initial_schema.sql new file mode 100644 index 000000000..f169d0008 --- /dev/null +++ b/contrib/database-session-service/src/main/resources/db/migration/spanner/V1__Initial_schema.sql @@ -0,0 +1,59 @@ +-- V1__Initial_schema.sql for Cloud Spanner +-- Initial database schema for ADK DatabaseSessionService (v1 format) +-- This schema matches Python ADK v1 with simplified event storage using JSON + +-- Create metadata table for schema versioning +CREATE TABLE adk_internal_metadata ( + key STRING(128) NOT NULL, + value STRING(256) +) PRIMARY KEY (key); + +-- Insert schema version (1 = v1 JSON schema format, compatible with Python ADK) +INSERT INTO adk_internal_metadata (key, value) VALUES ('schema_version', '1'); + +-- Create sessions table +CREATE TABLE sessions ( + app_name STRING(128) NOT NULL, + user_id STRING(128) NOT NULL, + id STRING(128) NOT NULL, + state JSON, + create_time TIMESTAMP, + update_time TIMESTAMP +) PRIMARY KEY (app_name, user_id, id); + +-- Create events table (v1 format with event_data JSON column) +-- Note: Spanner does not support traditional FOREIGN KEY constraints with ON DELETE CASCADE +-- We avoid INTERLEAVE IN PARENT to keep the schema simpler and compatible with the DAO layer +-- Applications must handle cascade deletes manually if needed +CREATE TABLE events ( + id STRING(128) NOT NULL, + app_name STRING(128) NOT NULL, + user_id STRING(128) NOT NULL, + session_id STRING(128) NOT NULL, + invocation_id STRING(256), + timestamp TIMESTAMP, + event_data JSON +) PRIMARY KEY (id, app_name, user_id, session_id); + +-- Create app states table +CREATE TABLE app_states ( + app_name STRING(128) NOT NULL, + state JSON, + update_time TIMESTAMP +) PRIMARY KEY (app_name); + +-- Create user states table +CREATE TABLE user_states ( + app_name STRING(128) NOT NULL, + user_id STRING(128) NOT NULL, + state JSON, + update_time TIMESTAMP +) PRIMARY KEY (app_name, user_id); + +-- Add indexes to improve query performance + +-- Index for looking up sessions by app_name and user_id +CREATE INDEX idx_sessions_app_user ON sessions(app_name, user_id); + +-- Index for sorting events by timestamp +CREATE INDEX idx_events_timestamp ON events(timestamp); diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppUserStateLockingTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppUserStateLockingTest.java new file mode 100644 index 000000000..61d6505ee --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppUserStateLockingTest.java @@ -0,0 +1,298 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests for pessimistic locking on app_states and user_states tables. + * + *

This test verifies that concurrent updates to app-level and user-level state from multiple + * sessions do not result in lost updates. Without pessimistic locking, concurrent read-modify-write + * operations can overwrite each other's changes. + */ +public class AppUserStateLockingTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:app_user_locking_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "app-user-lock-test"; + private static final String TEST_USER_ID = "test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + /** + * Tests that concurrent updates to app state from multiple threads on the SAME session preserve + * all changes. + * + *

Scenario: - 10 threads concurrently append events to the SAME session - Each event sets a + * unique key in app state - Expected: All 10 keys present - Without locking on app_states: some + * keys would be lost + * + *

Note: This tests the real-world pattern where events carry state deltas, not + * read-modify-write. + */ + @Test + public void testAppStateConcurrentUpdates_noLostUpdates() throws InterruptedException { + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + // Create initial session + String sharedSessionId = "shared-session"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sharedSessionId) + .blockingGet(); + + // Each thread appends event with a unique app state key + for (int i = 0; i < threadCount; i++) { + final int threadNum = i; + executor.submit( + () -> { + try { + // Each thread sets its own unique key in app state + ConcurrentMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("app:thread_" + threadNum, threadNum); + + Event event = + Event.builder() + .id("event-" + threadNum) + .author("thread-" + threadNum) + .content(Content.fromParts(Part.fromText("Increment app counter"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + } catch (Exception e) { + throw new RuntimeException("Thread " + threadNum + " failed", e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS), "Threads did not complete in time"); + executor.shutdown(); + + // Verify final counter value + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(finalSession); + + // Check that all thread keys are present + for (int i = 0; i < threadCount; i++) { + String key = "app:thread_" + i; + assertTrue(finalSession.state().containsKey(key), "app:thread_" + i + " should exist"); + assertEquals(i, finalSession.state().get(key), "app:thread_" + i + " should equal " + i); + } + } + + /** + * Tests that concurrent updates to user state from multiple sessions preserve all changes. + * + *

Scenario: - Same user has 10 different sessions (e.g., phone, laptop, tablet) - Each session + * concurrently increments user:notification_count - Expected final value: 10 (all updates + * preserved) - Without locking: final value would be < 10 (lost updates) + */ + @Test + public void testUserStateConcurrentUpdates_noLostUpdates() throws InterruptedException { + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + // Create initial session + String sharedSessionId = "shared-session"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sharedSessionId) + .blockingGet(); + + // Each thread appends event with a unique user state key + for (int i = 0; i < threadCount; i++) { + final int threadNum = i; + executor.submit( + () -> { + try { + // Each thread sets its own unique key in user state + ConcurrentMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("user:thread_" + threadNum, threadNum); + + Event event = + Event.builder() + .id("notif-" + threadNum) + .author("device-" + threadNum) + .content(Content.fromParts(Part.fromText("New notification"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + } catch (Exception e) { + throw new RuntimeException("Thread " + threadNum + " failed", e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS), "Threads did not complete in time"); + executor.shutdown(); + + // Verify final notification count + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(finalSession); + + // Check that all thread keys are present + for (int i = 0; i < threadCount; i++) { + String key = "user:thread_" + i; + assertTrue(finalSession.state().containsKey(key), "user:thread_" + i + " should exist"); + assertEquals(i, finalSession.state().get(key), "user:thread_" + i + " should equal " + i); + } + } + + /** + * Tests that concurrent updates to both app and user state work correctly. + * + *

Scenario: - 5 sessions concurrently update both app:total_requests and user:request_count - + * Tests that locks on app_states and user_states don't deadlock + */ + @Test + public void testConcurrentAppAndUserStateUpdates() throws InterruptedException { + int threadCount = 5; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + // Create initial session + String sharedSessionId = "shared-session"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sharedSessionId) + .blockingGet(); + + // Each thread appends event with both app and user state updates + for (int i = 0; i < threadCount; i++) { + final int threadNum = i; + executor.submit( + () -> { + try { + // Each thread sets unique keys in both app and user state + ConcurrentMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("app:req_" + threadNum, threadNum); + stateDelta.put("user:req_" + threadNum, threadNum); + + Event event = + Event.builder() + .id("req-" + threadNum) + .author("thread-" + threadNum) + .content(Content.fromParts(Part.fromText("API request"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + } catch (Exception e) { + throw new RuntimeException("Thread " + threadNum + " failed", e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS), "Threads did not complete in time"); + executor.shutdown(); + + // Verify both counters + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(finalSession); + + // Check that all app and user keys are present + for (int i = 0; i < threadCount; i++) { + String appKey = "app:req_" + i; + String userKey = "user:req_" + i; + assertTrue(finalSession.state().containsKey(appKey), appKey + " should exist"); + assertTrue(finalSession.state().containsKey(userKey), userKey + " should exist"); + assertEquals(i, finalSession.state().get(appKey), appKey + " should equal " + i); + assertEquals(i, finalSession.state().get(userKey), userKey + " should equal " + i); + } + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppendEventRaceConditionTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppendEventRaceConditionTest.java new file mode 100644 index 000000000..aa0d676a4 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppendEventRaceConditionTest.java @@ -0,0 +1,300 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import java.time.Instant; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Test that demonstrates the race condition between event emission and persistence. + * + *

This test proves that using {@code doOnNext()} to append events creates a race condition where + * events flow downstream before being persisted to the database, while using {@code flatMap()} + * correctly waits for persistence to complete. + */ +@RunWith(JUnit4.class) +public class AppendEventRaceConditionTest { + + private DatabaseSessionService sessionService; + private static final String APP_NAME = "race-test-app"; + private static final String USER_ID = "race-test-user"; + + @Before + public void setUp() throws Exception { + sessionService = + new DatabaseSessionService("jdbc:h2:mem:race_test_db;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="); + } + + @After + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + /** + * This test demonstrates the RACE CONDITION with doOnNext(). + * + *

Timeline: T=0ms: Event emitted T=1ms: doOnNext() fires appendEvent() (doesn't wait!) T=2ms: + * Event flows downstream immediately T=5ms: We query listEvents() T=6ms: Query reads database + * Result: Event might NOT be in database yet! ← RACE CONDITION T=100ms: Database write finally + * completes + */ + @Test + public void testDoOnNext_hasRaceCondition() throws Exception { + Session testSession = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), "session-doOnNext") + .blockingGet(); + + AtomicInteger eventsSeenInQuery = new AtomicInteger(0); + AtomicBoolean appendStarted = new AtomicBoolean(false); + CountDownLatch queryLatch = new CountDownLatch(1); + + Event testEvent = + Event.builder() + .id("race-event-1") + .invocationId("inv-1") + .author("test-agent") + .content(Content.builder().parts(Part.builder().text("Test").build()).build()) + .actions(EventActions.builder().build()) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Flowable.just(testEvent) + .doOnNext( + event -> { + appendStarted.set(true); + sessionService.appendEvent(testSession, event); + }) + .doOnNext( + event -> { + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + List events = + sessionService + .listEvents(APP_NAME, USER_ID, testSession.id()) + .blockingGet() + .events(); + eventsSeenInQuery.set(events.size()); + queryLatch.countDown(); + }) + .blockingSubscribe(); + + queryLatch.await(5, TimeUnit.SECONDS); + + System.out.println( + "doOnNext() test - Events seen in query: " + + eventsSeenInQuery.get() + + " (expected 0 or 1 due to race)"); + } + + /** + * This test demonstrates the CORRECT BEHAVIOR with flatMap(). + * + *

Timeline: T=0ms: Event emitted T=1ms: flatMap() calls appendEvent() T=2ms: Waits for + * appendEvent() Single to complete T=100ms: Database write completes T=101ms: Event flows + * downstream T=102ms: We query listEvents() T=103ms: Query reads database Result: Event IS in + * database! ← CORRECT + */ + @Test + public void testFlatMap_waitsForPersistence() throws Exception { + Session testSession = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), "session-flatMap") + .blockingGet(); + + AtomicInteger eventsSeenInQuery = new AtomicInteger(0); + CountDownLatch queryLatch = new CountDownLatch(1); + + Event testEvent = + Event.builder() + .id("race-event-2") + .invocationId("inv-2") + .author("test-agent") + .content(Content.builder().parts(Part.builder().text("Test").build()).build()) + .actions(EventActions.builder().build()) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Flowable.just(testEvent) + .flatMap( + event -> + sessionService + .appendEvent(testSession, event) + .toFlowable() + .onErrorResumeNext( + error -> { + System.err.println("Failed to append event: " + error.getMessage()); + return Flowable.just(event); + })) + .doOnNext( + event -> { + List events = + sessionService + .listEvents(APP_NAME, USER_ID, testSession.id()) + .blockingGet() + .events(); + eventsSeenInQuery.set(events.size()); + queryLatch.countDown(); + }) + .blockingSubscribe(); + + queryLatch.await(5, TimeUnit.SECONDS); + + System.out.println( + "flatMap() test - Events seen in query: " + + eventsSeenInQuery.get() + + " (expected 1 - always present)"); + + assertThat(eventsSeenInQuery.get()).isEqualTo(1); + } + + /** + * This test runs multiple iterations to increase the chance of catching the race condition. + * + *

With doOnNext(), we expect to see the race condition manifest as inconsistent query results. + * With flatMap(), we expect 100% consistency. + */ + @Test + public void testRaceCondition_multipleIterations() throws Exception { + int iterations = 10; + int doOnNextMisses = 0; + int flatMapMisses = 0; + + for (int i = 0; i < iterations; i++) { + final int iteration = i; + Session session = + sessionService + .createSession( + APP_NAME, "user-" + iteration, new ConcurrentHashMap<>(), "session-" + iteration) + .blockingGet(); + + Event event = + Event.builder() + .id("event-" + iteration) + .invocationId("inv-" + iteration) + .author("test") + .content( + Content.builder().parts(Part.builder().text("Test " + iteration).build()).build()) + .actions(EventActions.builder().build()) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + AtomicInteger doOnNextCount = new AtomicInteger(0); + CountDownLatch doOnNextLatch = new CountDownLatch(1); + + Flowable.just(event) + .doOnNext(e -> sessionService.appendEvent(session, e)) + .delay(10, TimeUnit.MILLISECONDS) + .doOnNext( + e -> { + int count = + sessionService + .listEvents(APP_NAME, "user-" + iteration, "session-" + iteration) + .blockingGet() + .events() + .size(); + doOnNextCount.set(count); + doOnNextLatch.countDown(); + }) + .blockingSubscribe(); + + doOnNextLatch.await(2, TimeUnit.SECONDS); + if (doOnNextCount.get() == 0) { + doOnNextMisses++; + } + + Session session2 = + sessionService + .createSession( + APP_NAME, + "user2-" + iteration, + new ConcurrentHashMap<>(), + "session2-" + iteration) + .blockingGet(); + + Event event2 = + Event.builder() + .id("event2-" + iteration) + .invocationId("inv2-" + iteration) + .author("test") + .content( + Content.builder() + .parts(Part.builder().text("Test2 " + iteration).build()) + .build()) + .actions(EventActions.builder().build()) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + AtomicInteger flatMapCount = new AtomicInteger(0); + CountDownLatch flatMapLatch = new CountDownLatch(1); + + Flowable.just(event2) + .flatMap( + e -> + sessionService + .appendEvent(session2, e) + .toFlowable() + .onErrorResumeNext(err -> Flowable.just(e))) + .doOnNext( + e -> { + int count = + sessionService + .listEvents(APP_NAME, "user2-" + iteration, "session2-" + iteration) + .blockingGet() + .events() + .size(); + flatMapCount.set(count); + flatMapLatch.countDown(); + }) + .blockingSubscribe(); + + flatMapLatch.await(2, TimeUnit.SECONDS); + if (flatMapCount.get() == 0) { + flatMapMisses++; + } + } + + System.out.println("Race condition test results over " + iterations + " iterations:"); + System.out.println( + " doOnNext() misses: " + doOnNextMisses + " (race condition manifestations)"); + System.out.println(" flatMap() misses: " + flatMapMisses + " (should be 0)"); + + assertThat(flatMapMisses).isEqualTo(0); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppendEventTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppendEventTest.java new file mode 100644 index 000000000..ec7ab3f7e --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppendEventTest.java @@ -0,0 +1,259 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.*; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class AppendEventTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:append_failure_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private static final String TEST_APP_NAME = "failure-test-app"; + private static final String TEST_USER_ID = "failure-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testAppendEvent_dbWriteFailsDueToClosedService_memoryUnchanged() { + String sessionId = "db-fail-test"; + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + assertEquals(0, session.events().size()); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.close(); + + try { + sessionService.appendEvent(session, event).blockingGet(); + fail("Expected IllegalStateException when service is closed"); + } catch (IllegalStateException e) { + assertTrue(e.getMessage().contains("closed")); + } + + assertEquals(0, session.events().size(), "Memory should remain unchanged when DB write fails"); + } + + @Test + public void testAppendEvent_nonExistentSession_throwsException() { + String sessionId = "non-existent-session"; + Session fakeSession = + Session.builder(sessionId) + .appName(TEST_APP_NAME) + .userId(TEST_USER_ID) + .state(new ConcurrentHashMap<>()) + .build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + try { + sessionService.appendEvent(fakeSession, event).blockingGet(); + fail("Expected SessionNotFoundException for non-existent session"); + } catch (SessionNotFoundException e) { + assertTrue(e.getMessage().contains("Session not found")); + } + } + + @Test + public void testAppendEvent_validStateDelta_persistsCorrectly() throws Exception { + String sessionId = "valid-state-test"; + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "valid_key", "valid_value"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Event appendedEvent = sessionService.appendEvent(session, event).blockingGet(); + + assertNotNull(appendedEvent); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertEquals("valid_value", retrieved.state().get(State.APP_PREFIX + "valid_key")); + assertEquals(1, retrieved.events().size()); + } + + @Test + public void testAppendEvent_successfulAppendAndRetrieval() throws Exception { + String sessionId = "success-test"; + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Event appended = sessionService.appendEvent(session, event).blockingGet(); + assertNotNull(appended); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertEquals(1, retrieved.events().size()); + } + + @Test + public void testAppendEvent_concurrentModificationWithRollback() throws Exception { + String sessionId = "concurrent-rollback-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "counter", 0); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta1 = new ConcurrentHashMap<>(); + delta1.put(State.APP_PREFIX + "counter", 1); + + EventActions actions1 = EventActions.builder().stateDelta(delta1).build(); + + Event event1 = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Event 1"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions1) + .build(); + + sessionService.appendEvent(session, event1).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertEquals(1, retrieved.state().get(State.APP_PREFIX + "counter")); + assertEquals(1, retrieved.events().size()); + } + + @Test + public void testAppendEvent_multipleConcurrentFailures() throws Exception { + String sessionId = "multi-failure-test"; + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event event1 = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Event 1"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(session, event1).blockingGet(); + + assertEquals(1, session.events().size()); + + sessionService.close(); + + Event event2 = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Event 2"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + try { + sessionService.appendEvent(session, event2).blockingGet(); + fail("Expected IllegalStateException after close"); + } catch (IllegalStateException e) { + assertTrue(e.getMessage().contains("closed")); + } + + assertEquals(1, session.events().size(), "Event count should remain at 1 after failed append"); + } + + @Test + public void testAppendEvent_errorLogging() { + String sessionId = "error-logging-test"; + Session session = + Session.builder(sessionId) + .appName(TEST_APP_NAME) + .userId(TEST_USER_ID) + .state(new ConcurrentHashMap<>()) + .build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + try { + sessionService.appendEvent(session, event).blockingGet(); + fail("Should throw SessionNotFoundException"); + } catch (SessionNotFoundException e) { + assertTrue( + e.getMessage().contains("Session not found"), "Error message should be descriptive"); + } + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/ConcurrentSessionOperationsTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ConcurrentSessionOperationsTest.java new file mode 100644 index 000000000..7ffb162ab --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ConcurrentSessionOperationsTest.java @@ -0,0 +1,498 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ConcurrentSessionOperationsTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:concurrency_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private static final String TEST_APP_NAME = "concurrency-test-app"; + private static final String TEST_USER_ID = "concurrency-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testConcurrentEventAppends() throws InterruptedException { + String sessionId = "concurrent-append-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int threadCount = 10; + int eventsPerThread = 5; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + for (int i = 0; i < eventsPerThread; i++) { + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("thread-" + threadId) + .content(Content.fromParts(Part.fromText("Event from thread " + threadId))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(session); + assertEquals(threadCount * eventsPerThread, session.events().size()); + } + + @Test + public void testConcurrentSessionCreations() throws InterruptedException { + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + List sessionIds = new ArrayList<>(); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + String sessionId = "session-" + threadId; + sessionIds.add(sessionId); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("thread", threadId); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId) + .blockingGet(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + for (String sessionId : sessionIds) { + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(session); + } + } + + @Test + public void testConcurrentReadsAndWrites() throws InterruptedException { + String sessionId = "read-write-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int readerCount = 5; + int writerCount = 5; + ExecutorService executor = Executors.newFixedThreadPool(readerCount + writerCount); + CountDownLatch latch = new CountDownLatch(readerCount + writerCount); + + for (int i = 0; i < writerCount; i++) { + final int writerId = i; + executor.submit( + () -> { + try { + for (int j = 0; j < 3; j++) { + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("writer-" + writerId) + .content(Content.fromParts(Part.fromText("Event"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(20); + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + for (int i = 0; i < readerCount; i++) { + executor.submit( + () -> { + try { + for (int j = 0; j < 5; j++) { + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(60, TimeUnit.SECONDS); + executor.shutdown(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(session); + assertEquals(writerCount * 3, session.events().size()); + } + + @Test + public void testConcurrentAppStateUpdates() throws InterruptedException { + int threadCount = 5; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + String sessionId = "app-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("_app_counter", threadId); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId) + .blockingGet(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "app-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(session); + assertNotNull(session.state().get("_app_counter")); + } + + @Test + public void testConcurrentCreateSessionsWithSameAppName_noStateCorruption() + throws InterruptedException { + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + TimeUnit.MILLISECONDS.sleep(10); + executor.submit( + () -> { + try { + String sessionId = "concurrent-app-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put(State.APP_PREFIX + "key_" + threadId, "value_" + threadId); + + sessionService + .createSession(TEST_APP_NAME, "user-" + threadId, state, sessionId) + .blockingGet(); + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session verifySession = + sessionService + .getSession(TEST_APP_NAME, "user-0", "concurrent-app-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(verifySession); + for (int i = 0; i < threadCount; i++) { + String key = State.APP_PREFIX + "key_" + i; + assertEquals( + "value_" + i, + verifySession.state().get(key), + "App state should contain all keys from concurrent creates without corruption"); + } + } + + @Test + public void testConcurrentCreateSessionsWithSameUser_noUserStateCorruption() + throws InterruptedException { + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + TimeUnit.MILLISECONDS.sleep(10); + executor.submit( + () -> { + try { + String sessionId = "concurrent-user-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put(State.USER_PREFIX + "pref_" + threadId, threadId * 100); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId) + .blockingGet(); + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session verifySession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "concurrent-user-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(verifySession); + for (int i = 0; i < threadCount; i++) { + String key = State.USER_PREFIX + "pref_" + i; + assertEquals( + i * 100, + verifySession.state().get(key), + "User state should contain all keys from concurrent creates without corruption"); + } + } + + @Test + public void testConcurrentCreateSessionsWithMixedStateUpdates() throws InterruptedException { + int threadCount = 8; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + String sessionId = "mixed-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put(State.APP_PREFIX + "shared_app_key", "app_value_" + threadId); + state.put(State.USER_PREFIX + "user_pref_" + threadId, threadId); + state.put("session_local", "local_" + threadId); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId) + .blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session session0 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "mixed-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(session0); + assertNotNull(session0.state().get(State.APP_PREFIX + "shared_app_key")); + + for (int i = 0; i < threadCount; i++) { + assertEquals(i, session0.state().get(State.USER_PREFIX + "user_pref_" + i), "User pref " + i); + } + + for (int i = 0; i < threadCount; i++) { + Session sessionI = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "mixed-state-" + i, Optional.empty()) + .blockingGet(); + assertEquals("local_" + i, sessionI.state().get("session_local")); + } + } + + @Test + public void testConcurrentCreateSessionsWithPreExistingAppState_noLag() + throws InterruptedException { + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "key_initial", "initial_value"); + sessionService + .createSession(TEST_APP_NAME, "initial-user", initialState, "initial-session") + .blockingGet(); + + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + String sessionId = "pre-existing-app-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put(State.APP_PREFIX + "key_" + threadId, "value_" + threadId); + + sessionService + .createSession(TEST_APP_NAME, "user-" + threadId, state, sessionId) + .blockingGet(); + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session verifySession = + sessionService + .getSession(TEST_APP_NAME, "user-0", "pre-existing-app-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(verifySession); + assertEquals( + "initial_value", + verifySession.state().get(State.APP_PREFIX + "key_initial"), + "Initial app state key should be preserved"); + for (int i = 0; i < threadCount; i++) { + String key = State.APP_PREFIX + "key_" + i; + assertEquals( + "value_" + i, + verifySession.state().get(key), + "App state should contain all keys when row pre-exists (SELECT FOR UPDATE works)"); + } + } + + @Test + public void testConcurrentCreateSessionsWithPreExistingUserState_noLag() + throws InterruptedException { + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.USER_PREFIX + "pref_initial", -1); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, "initial-session") + .blockingGet(); + + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + String sessionId = "pre-existing-user-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put(State.USER_PREFIX + "pref_" + threadId, threadId * 100); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId) + .blockingGet(); + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session verifySession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "pre-existing-user-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(verifySession); + assertEquals( + -1, + verifySession.state().get(State.USER_PREFIX + "pref_initial"), + "Initial user state key should be preserved"); + for (int i = 0; i < threadCount; i++) { + String key = State.USER_PREFIX + "pref_" + i; + assertEquals( + i * 100, + verifySession.state().get(key), + "User state should contain all keys when row pre-exists (SELECT FOR UPDATE works)"); + } + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/ContentSerializationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ContentSerializationTest.java new file mode 100644 index 000000000..d09fdf8b7 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ContentSerializationTest.java @@ -0,0 +1,522 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.ExecutableCode; +import com.google.genai.types.FileData; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ContentSerializationTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:testdb_content;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "content-test-app"; + private static final String TEST_USER_ID = "content-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + sessionService.close(); + } + + @Test + public void testTextPartRoundTrip() { + String sessionId = "text-part-test"; + String testText = "Hello, world! This is a test message."; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText(testText))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals(1, retrievedSession.events().size()); + + Event retrievedEvent = retrievedSession.events().get(0); + assertNotNull(retrievedEvent.content()); + assertTrue(retrievedEvent.content().isPresent()); + + Content content = retrievedEvent.content().get(); + assertNotNull(content.parts()); + assertTrue(content.parts().isPresent()); + + List parts = content.parts().get(); + assertEquals(1, parts.size()); + + Part part = parts.get(0); + assertTrue(part.text().isPresent()); + assertEquals(testText, part.text().get()); + } + + @Test + public void testFunctionCallPartRoundTrip() { + String sessionId = "function-call-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + FunctionCall functionCall = + FunctionCall.builder() + .name("get_weather") + .args(Map.of("location", "San Francisco", "unit", "celsius")) + .id("call-123") + .build(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("model") + .content(Content.fromParts(Part.builder().functionCall(functionCall).build())) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals(1, retrievedSession.events().size()); + + Event retrievedEvent = retrievedSession.events().get(0); + Content content = retrievedEvent.content().get(); + Part part = content.parts().get().get(0); + + assertTrue(part.functionCall().isPresent()); + FunctionCall retrievedCall = part.functionCall().get(); + + assertEquals("get_weather", retrievedCall.name().get()); + assertEquals("call-123", retrievedCall.id().get()); + + Map retrievedArgs = retrievedCall.args().get(); + assertEquals("San Francisco", retrievedArgs.get("location")); + assertEquals("celsius", retrievedArgs.get("unit")); + } + + @Test + public void testFileDataPartRoundTrip() { + String sessionId = "file-data-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + FileData fileData = + FileData.builder() + .fileUri("gs://bucket/path/to/file.pdf") + .mimeType("application/pdf") + .build(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("user") + .content(Content.fromParts(Part.builder().fileData(fileData).build())) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + Event retrievedEvent = retrievedSession.events().get(0); + Part part = retrievedEvent.content().get().parts().get().get(0); + + assertTrue(part.fileData().isPresent()); + FileData retrievedFileData = part.fileData().get(); + + assertEquals("gs://bucket/path/to/file.pdf", retrievedFileData.fileUri().get()); + assertEquals("application/pdf", retrievedFileData.mimeType().get()); + } + + @Test + public void testFunctionResponsePartRoundTrip() { + String sessionId = "function-response-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + FunctionResponse functionResponse = + FunctionResponse.builder() + .name("get_weather") + .response(Map.of("temperature", 72, "conditions", "sunny")) + .id("call-123") + .build(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("tool") + .content(Content.fromParts(Part.builder().functionResponse(functionResponse).build())) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + Event retrievedEvent = retrievedSession.events().get(0); + Part part = retrievedEvent.content().get().parts().get().get(0); + + assertTrue(part.functionResponse().isPresent()); + FunctionResponse retrievedResponse = part.functionResponse().get(); + + assertEquals("get_weather", retrievedResponse.name().get()); + assertEquals("call-123", retrievedResponse.id().get()); + + Map responseData = retrievedResponse.response().get(); + assertEquals(72, responseData.get("temperature")); + assertEquals("sunny", responseData.get("conditions")); + } + + @Test + public void testExecutableCodePartRoundTrip() { + String sessionId = "executable-code-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ExecutableCode executableCode = ExecutableCode.builder().code("print('Hello, World!')").build(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("model") + .content(Content.fromParts(Part.builder().executableCode(executableCode).build())) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + Event retrievedEvent = retrievedSession.events().get(0); + Part part = retrievedEvent.content().get().parts().get().get(0); + + assertTrue(part.executableCode().isPresent()); + ExecutableCode retrievedCode = part.executableCode().get(); + + assertEquals("print('Hello, World!')", retrievedCode.code().get()); + } + + @Test + public void testMixedPartsInSingleEvent() { + String sessionId = "mixed-parts-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Part textPart = Part.fromText("Let me call a function:"); + Part functionCallPart = + Part.builder() + .functionCall( + FunctionCall.builder() + .name("calculate") + .args(Map.of("expression", "2+2")) + .id("calc-1") + .build()) + .build(); + Part fileDataPart = + Part.builder() + .fileData( + FileData.builder().fileUri("gs://bucket/data.csv").mimeType("text/csv").build()) + .build(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("model") + .content(Content.fromParts(textPart, functionCallPart, fileDataPart)) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + Event retrievedEvent = retrievedSession.events().get(0); + List parts = retrievedEvent.content().get().parts().get(); + + assertEquals(3, parts.size()); + + assertTrue(parts.get(0).text().isPresent()); + assertEquals("Let me call a function:", parts.get(0).text().get()); + + assertTrue(parts.get(1).functionCall().isPresent()); + assertEquals("calculate", parts.get(1).functionCall().get().name().get()); + + assertTrue(parts.get(2).fileData().isPresent()); + assertEquals("gs://bucket/data.csv", parts.get(2).fileData().get().fileUri().get()); + } + + /** + * Tests that a multi-turn conversation with function calls is correctly serialized and + * deserialized. This verifies the complete workflow: user message -> model function call -> tool + * response -> model final response. + * + *

IMPORTANT: Events are created with incrementing timestamps (100ms apart) to simulate + * realistic timing. In production, events naturally have different timestamps due to processing + * delays. Without timestamp separation, events with identical timestamps would have undefined + * ordering since the database only sorts by timestamp. This test previously failed intermittently + * because it created all events with Instant.now() within the same millisecond, causing + * non-deterministic ordering. + */ + @Test + public void testMultiTurnConversationWithTools() { + String sessionId = "multi-turn-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + long baseTimestamp = Instant.now().toEpochMilli(); + + Event userMessage = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("user") + .content(Content.fromParts(Part.fromText("What's the weather in Tokyo?"))) + .timestamp(baseTimestamp) + .build(); + + Event modelFunctionCall = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("model") + .content( + Content.fromParts( + Part.builder() + .functionCall( + FunctionCall.builder() + .name("get_weather") + .args(Map.of("city", "Tokyo")) + .id("weather-1") + .build()) + .build())) + .timestamp(baseTimestamp + 100) + .build(); + + Event toolResponse = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("tool") + .content( + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .name("get_weather") + .response(Map.of("temp", 18, "condition", "cloudy")) + .id("weather-1") + .build()) + .build())) + .timestamp(baseTimestamp + 200) + .build(); + + Event modelFinalResponse = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("model") + .content(Content.fromParts(Part.fromText("The weather in Tokyo is 18°C and cloudy."))) + .timestamp(baseTimestamp + 300) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, userMessage).blockingGet(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, modelFunctionCall).blockingGet(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, toolResponse).blockingGet(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, modelFinalResponse).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertEquals(4, retrievedSession.events().size()); + + assertEquals( + "What's the weather in Tokyo?", + retrievedSession.events().get(0).content().get().parts().get().get(0).text().get()); + + FunctionCall retrievedCall = + retrievedSession.events().get(1).content().get().parts().get().get(0).functionCall().get(); + assertEquals("get_weather", retrievedCall.name().get()); + assertEquals("Tokyo", retrievedCall.args().get().get("city")); + + FunctionResponse retrievedResponse = + retrievedSession + .events() + .get(2) + .content() + .get() + .parts() + .get() + .get(0) + .functionResponse() + .get(); + assertEquals("get_weather", retrievedResponse.name().get()); + assertEquals(18, retrievedResponse.response().get().get("temp")); + + assertEquals( + "The weather in Tokyo is 18°C and cloudy.", + retrievedSession.events().get(3).content().get().parts().get().get(0).text().get()); + } + + @Test + public void testEmptyAndNullContent() { + String sessionId = "empty-content-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event emptyContentEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("system") + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, emptyContentEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + Event retrievedEvent = retrievedSession.events().get(0); + assertTrue( + retrievedEvent.content().isEmpty() || retrievedEvent.content().get().parts().isEmpty()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/CreateSessionInTransactionTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/CreateSessionInTransactionTest.java new file mode 100644 index 000000000..75cd82b91 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/CreateSessionInTransactionTest.java @@ -0,0 +1,246 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class CreateSessionInTransactionTest { + + private static final String TEST_DB_URL = "jdbc:h2:mem:create_session_test;DB_CLOSE_DELAY=-1"; + private static final String TEST_APP_NAME = "test-app"; + private static final String TEST_USER_ID = "test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testStatePrefixSplitting_appPrefix() { + String sessionId = "prefix-split-app"; + ConcurrentMap state = new ConcurrentHashMap<>(); + state.put(State.APP_PREFIX + "app_key", "app_value"); + state.put("session_key", "session_value"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals("app_value", session.state().get(State.APP_PREFIX + "app_key")); + assertEquals("session_value", session.state().get("session_key")); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertEquals("app_value", retrieved.state().get(State.APP_PREFIX + "app_key")); + assertEquals("session_value", retrieved.state().get("session_key")); + } + + @Test + public void testStatePrefixSplitting_userPrefix() { + String sessionId = "prefix-split-user"; + ConcurrentMap state = new ConcurrentHashMap<>(); + state.put(State.USER_PREFIX + "user_key", "user_value"); + state.put("session_key", "session_value"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals("user_value", session.state().get(State.USER_PREFIX + "user_key")); + assertEquals("session_value", session.state().get("session_key")); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertEquals("user_value", retrieved.state().get(State.USER_PREFIX + "user_key")); + assertEquals("session_value", retrieved.state().get("session_key")); + } + + @Test + public void testStatePrefixSplitting_allThreeTiers() { + String sessionId = "prefix-split-all"; + ConcurrentMap state = new ConcurrentHashMap<>(); + state.put(State.APP_PREFIX + "app_setting", "global"); + state.put(State.USER_PREFIX + "user_pref", "personal"); + state.put("session_data", "private"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals("global", session.state().get(State.APP_PREFIX + "app_setting")); + assertEquals("personal", session.state().get(State.USER_PREFIX + "user_pref")); + assertEquals("private", session.state().get("session_data")); + } + + @Test + public void testTempPrefixIgnored() { + String sessionId = "temp-ignored"; + ConcurrentMap state = new ConcurrentHashMap<>(); + state.put(State.TEMP_PREFIX + "temp_key", "should_be_ignored"); + state.put("persisted_key", "should_persist"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals("should_persist", session.state().get("persisted_key")); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNull(retrieved.state().get(State.TEMP_PREFIX + "temp_key")); + assertEquals("should_persist", retrieved.state().get("persisted_key")); + } + + @Test + public void testUuidGeneratedWhenSessionIdNull() { + ConcurrentMap state = new ConcurrentHashMap<>(); + state.put("key", "value"); + + Session session1 = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, null).blockingGet(); + Session session2 = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, null).blockingGet(); + + assertNotNull(session1.id()); + assertNotNull(session2.id()); + assertNotEquals(session1.id(), session2.id()); + assertTrue(session1.id().length() > 0); + assertTrue(session2.id().length() > 0); + } + + @Test + public void testUuidGeneratedWhenSessionIdEmpty() { + ConcurrentMap state = new ConcurrentHashMap<>(); + state.put("key", "value"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, "").blockingGet(); + + assertNotNull(session.id()); + assertNotEquals("", session.id()); + assertTrue(session.id().length() > 0); + } + + @Test + public void testNullStateHandled() { + String sessionId = "null-state"; + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, null, sessionId).blockingGet(); + + assertNotNull(session); + assertNotNull(session.state()); + } + + @Test + public void testEmptyStateHandled() { + String sessionId = "empty-state"; + ConcurrentMap state = new ConcurrentHashMap<>(); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertNotNull(session.state()); + } + + @Test + public void testAppStateSharedAcrossSessions() { + String sessionId1 = "shared-app-1"; + String sessionId2 = "shared-app-2"; + + ConcurrentMap state1 = new ConcurrentHashMap<>(); + state1.put(State.APP_PREFIX + "shared_counter", 10); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + ConcurrentMap state2 = new ConcurrentHashMap<>(); + state2.put(State.APP_PREFIX + "shared_counter", 20); + + sessionService.createSession(TEST_APP_NAME, "other-user", state2, sessionId2).blockingGet(); + + Session retrieved1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId1, Optional.empty()) + .blockingGet(); + + assertEquals(20, retrieved1.state().get(State.APP_PREFIX + "shared_counter")); + } + + @Test + public void testUserStateIsolatedBetweenUsers() { + String sessionId1 = "user-isolated-1"; + String sessionId2 = "user-isolated-2"; + + ConcurrentMap state1 = new ConcurrentHashMap<>(); + state1.put(State.USER_PREFIX + "preference", "dark"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + ConcurrentMap state2 = new ConcurrentHashMap<>(); + state2.put(State.USER_PREFIX + "preference", "light"); + + sessionService.createSession(TEST_APP_NAME, "other-user", state2, sessionId2).blockingGet(); + + Session retrieved1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId1, Optional.empty()) + .blockingGet(); + Session retrieved2 = + sessionService + .getSession(TEST_APP_NAME, "other-user", sessionId2, Optional.empty()) + .blockingGet(); + + assertEquals("dark", retrieved1.state().get(State.USER_PREFIX + "preference")); + assertEquals("light", retrieved2.state().get(State.USER_PREFIX + "preference")); + } + + @Test + public void testComplexNestedState() { + String sessionId = "nested-state"; + ConcurrentMap state = new ConcurrentHashMap<>(); + + ConcurrentHashMap nestedObject = new ConcurrentHashMap<>(); + nestedObject.put("nested_key", "nested_value"); + nestedObject.put("nested_number", 42); + + state.put(State.APP_PREFIX + "config", nestedObject); + state.put("simple", "value"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + Object configObj = retrieved.state().get(State.APP_PREFIX + "config"); + assertNotNull(configObj); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/DatabaseSessionServiceTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DatabaseSessionServiceTest.java new file mode 100644 index 000000000..01d6507b5 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DatabaseSessionServiceTest.java @@ -0,0 +1,693 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class DatabaseSessionServiceTest { + + private static final String TEST_DB_URL = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private static final String TEST_APP_NAME = "test-app"; + private static final String TEST_USER_ID = "test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement()) { + stmt.execute("DELETE FROM events"); + stmt.execute("DELETE FROM sessions"); + stmt.execute("DELETE FROM app_states"); + stmt.execute("DELETE FROM user_states"); + } catch (Exception e) { + } + sessionService.close(); + } + } + + private long countEventsInDatabase(String sessionId) throws Exception { + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement(); + ResultSet rs = + stmt.executeQuery( + "SELECT COUNT(*) FROM events WHERE session_id = '" + sessionId + "'")) { + if (rs.next()) { + return rs.getLong(1); + } + return 0; + } + } + + private long countSessionsInDatabase(String sessionId) throws Exception { + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement(); + ResultSet rs = + stmt.executeQuery("SELECT COUNT(*) FROM sessions WHERE id = '" + sessionId + "'")) { + if (rs.next()) { + return rs.getLong(1); + } + return 0; + } + } + + @Test + public void testCreateSession() { + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key1", "value1"); + state.put("key2", 42); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, null).blockingGet(); + + assertNotNull(session); + assertNotNull(session.id()); + assertEquals(TEST_APP_NAME, session.appName()); + assertEquals(TEST_USER_ID, session.userId()); + assertEquals("value1", session.state().get("key1")); + assertEquals(42, session.state().get("key2")); + assertTrue(session.events().isEmpty()); + } + + @Test + public void testCreateSessionWithId() { + String sessionId = "custom-session-id"; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals(sessionId, session.id()); + assertEquals(TEST_APP_NAME, session.appName()); + assertEquals(TEST_USER_ID, session.userId()); + } + + @Test + public void testGetSession() { + String sessionId = "get-session-test"; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key", "value"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals(sessionId, retrievedSession.id()); + assertEquals(TEST_APP_NAME, retrievedSession.appName()); + assertEquals(TEST_USER_ID, retrievedSession.userId()); + assertEquals("value", retrievedSession.state().get("key")); + } + + @Test + public void testGetSessionNotFound() { + String nonExistentId = "non-existent"; + + assertNull( + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, nonExistentId, Optional.empty()) + .blockingGet()); + } + + @Test + public void testListSessionsEmpty() { + ListSessionsResponse response = + sessionService.listSessions(TEST_APP_NAME, TEST_USER_ID).blockingGet(); + + assertNotNull(response); + assertEquals(0, response.sessions().size()); + } + + @Test + public void testListSessions() { + String sessionId1 = "list-test-1"; + String sessionId2 = "list-test-2"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId1) + .blockingGet(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId2) + .blockingGet(); + + ListSessionsResponse response = + sessionService.listSessions(TEST_APP_NAME, TEST_USER_ID).blockingGet(); + + assertNotNull(response); + List sessions = response.sessions(); + assertEquals(2, sessions.size()); + assertTrue(sessions.stream().anyMatch(s -> s.id().equals(sessionId1))); + assertTrue(sessions.stream().anyMatch(s -> s.id().equals(sessionId2))); + } + + @Test + public void testAppendEvent() { + String sessionId = "event-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Hello, world!"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updatedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updatedSession); + assertEquals(1, updatedSession.events().size()); + Event retrievedEvent = updatedSession.events().get(0); + assertEquals(event.id(), retrievedEvent.id()); + assertEquals(event.author(), retrievedEvent.author()); + } + + @Test + public void testAppendEventToNonExistentSession() { + String nonExistentId = "non-existent"; + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Hello, world!"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session nonExistentSession = + Session.builder(nonExistentId) + .appName(TEST_APP_NAME) + .userId(TEST_USER_ID) + .state(new ConcurrentHashMap<>()) + .events(new ArrayList<>()) + .build(); + assertThrows( + SessionNotFoundException.class, + () -> sessionService.appendEvent(nonExistentSession, event).blockingGet()); + } + + @Test + public void testDeleteSession() { + String sessionId = "delete-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + assertNotNull( + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet()); + + sessionService.deleteSession(TEST_APP_NAME, TEST_USER_ID, sessionId).blockingAwait(); + + assertNull( + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet()); + } + + @Test + public void testListEvents() { + String sessionId = "list-events-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("index: " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + try { + TimeUnit.MILLISECONDS.sleep(10); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + ListEventsResponse response = + sessionService.listEvents(TEST_APP_NAME, TEST_USER_ID, sessionId).blockingGet(); + + assertNotNull(response); + assertEquals(5, response.events().size()); + } + + @Test + public void testGetSessionWithNumRecentEvents() { + String sessionId = "filter-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant startTime = Instant.now(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("index: " + i))) + .timestamp(startTime.plusSeconds(i).toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(2).build(); + Session sessionWithRecentEvents = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(sessionWithRecentEvents); + assertEquals(2, sessionWithRecentEvents.events().size()); + } + + @Test + public void testAppendEventUpdatesSessionState() { + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("sessionKey", "sessionValue"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertNotNull(retrievedSession); + assertEquals("sessionValue", retrievedSession.state().get("sessionKey")); + } + + @Test + public void testAppendEventUpdatesAppState() { + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), "session2") + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("_app_appKey", "appValue"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertNotNull(retrievedSession); + assertEquals("appValue", retrievedSession.state().get("_app_appKey")); + } + + @Test + public void testAppendEventUpdatesUserState() { + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), "session3") + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("_user_userKey", "userValue"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertNotNull(retrievedSession); + assertEquals("userValue", retrievedSession.state().get("_user_userKey")); + } + + @Test + public void testDeleteSessionRemovesAllRelatedData() throws Exception { + String sessionId = "delete-cascade-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + long eventsBefore = countEventsInDatabase(sessionId); + assertEquals(5, eventsBefore, "Should have 5 events before deletion"); + + long sessionsBefore = countSessionsInDatabase(sessionId); + assertEquals(1, sessionsBefore, "Should have 1 session before deletion"); + + sessionService.deleteSession(TEST_APP_NAME, TEST_USER_ID, sessionId).blockingAwait(); + + long eventsAfter = countEventsInDatabase(sessionId); + assertEquals(0, eventsAfter, "All events should be deleted from database"); + + long sessionsAfter = countSessionsInDatabase(sessionId); + assertEquals(0, sessionsAfter, "Session should be deleted from database"); + } + + @Test + public void testEventsPersistAfterMultipleReads() throws Exception { + String sessionId = "persist-after-reads-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 3; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + long eventsBeforeReads = countEventsInDatabase(sessionId); + assertEquals(3, eventsBeforeReads); + + for (int i = 0; i < 5; i++) { + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + } + + long eventsAfterReads = countEventsInDatabase(sessionId); + assertEquals(3, eventsAfterReads, "Events should persist in database after multiple reads"); + } + + @Test + public void testAppendEventWithNullContent() { + String sessionId = "null-content-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event eventWithNullContent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Optional.empty()) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, eventWithNullContent).blockingGet(); + Session updatedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updatedSession); + assertEquals(1, updatedSession.events().size()); + } + + @Test + public void testEmptyStateDelta() { + String sessionId = "empty-delta-test"; + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event eventWithEmptyDelta = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(new ConcurrentHashMap<>()).build()) + .build(); + + sessionService.appendEvent(session, eventWithEmptyDelta).blockingGet(); + Session updatedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updatedSession); + assertEquals(1, updatedSession.events().size()); + } + + @Test + public void testNullStateDeltaHandling() { + String sessionId = "null-delta-test"; + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event eventWithNullActions = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(null) + .build(); + + sessionService.appendEvent(session, eventWithNullActions).blockingGet(); + Session updatedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updatedSession); + assertEquals(1, updatedSession.events().size()); + } + + @Test + public void testAppendEventWithRemovedDeletesKeys() throws Exception { + String sessionId = UUID.randomUUID().toString(); + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "app_key", "app_value"); + initialState.put(State.USER_PREFIX + "user_key", "user_value"); + initialState.put("session_key", "session_value"); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + assertNotNull(session); + assertEquals("app_value", session.state().get(State.APP_PREFIX + "app_key")); + assertEquals("user_value", session.state().get(State.USER_PREFIX + "user_key")); + assertEquals("session_value", session.state().get("session_key")); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "app_key", State.REMOVED); + delta.put(State.USER_PREFIX + "user_key", State.REMOVED); + delta.put("session_key", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Remove keys"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertFalse(updated.state().containsKey(State.APP_PREFIX + "app_key")); + assertFalse(updated.state().containsKey(State.USER_PREFIX + "user_key")); + assertFalse(updated.state().containsKey("session_key")); + + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement()) { + + ResultSet rs = + stmt.executeQuery( + "SELECT state FROM app_states WHERE app_name = '" + TEST_APP_NAME + "'"); + if (rs.next()) { + String appStateJson = rs.getString("state"); + assertFalse(appStateJson.contains("app_key"), "app_key should be removed from database"); + } + + rs = + stmt.executeQuery( + "SELECT state FROM user_states WHERE app_name = '" + + TEST_APP_NAME + + "' AND user_id = '" + + TEST_USER_ID + + "'"); + if (rs.next()) { + String userStateJson = rs.getString("state"); + assertFalse(userStateJson.contains("user_key"), "user_key should be removed from database"); + } + + rs = stmt.executeQuery("SELECT state FROM sessions WHERE id = '" + sessionId + "'"); + if (rs.next()) { + String sessionStateJson = rs.getString("state"); + assertFalse( + sessionStateJson.contains("session_key"), + "session_key should be removed from database"); + } + } + } + + @Test + public void testRemovedOnlyAffectsSpecifiedTier() throws Exception { + String sessionId = UUID.randomUUID().toString(); + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "app_keep", "app_value"); + initialState.put(State.APP_PREFIX + "app_remove", "remove_this"); + initialState.put(State.USER_PREFIX + "user_keep", "user_value"); + initialState.put(State.USER_PREFIX + "user_remove", "remove_this"); + initialState.put("session_keep", "session_value"); + initialState.put("session_remove", "remove_this"); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "app_remove", State.REMOVED); + delta.put(State.USER_PREFIX + "user_remove", State.REMOVED); + delta.put("session_remove", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Selective removal"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + + assertEquals("app_value", updated.state().get(State.APP_PREFIX + "app_keep")); + assertFalse(updated.state().containsKey(State.APP_PREFIX + "app_remove")); + + assertEquals("user_value", updated.state().get(State.USER_PREFIX + "user_keep")); + assertFalse(updated.state().containsKey(State.USER_PREFIX + "user_remove")); + + assertEquals("session_value", updated.state().get("session_keep")); + assertFalse(updated.state().containsKey("session_remove")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/DiagnosticTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DiagnosticTest.java new file mode 100644 index 000000000..82a6a8591 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DiagnosticTest.java @@ -0,0 +1,112 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.events.Event; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +@Tag("integration") +public class DiagnosticTest { + private static final String TEST_DB_URL = TestDatabaseConfig.MYSQL_JDBC_URL; + private static final String TEST_APP_NAME = "diagnostic-test"; + private static final String TEST_USER_ID = "diagnostic-user"; + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isMySQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("MySQL")); + sessionService = new DatabaseSessionService(TEST_DB_URL, new java.util.HashMap<>()); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void diagnosticAfterTimestampFiltering() { + String sessionId = "diag-" + System.currentTimeMillis(); + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant baseTime = Instant.now(); + System.out.println("Base time: " + baseTime); + + // Create 10 events + for (int i = 1; i <= 10; i++) { + Instant eventTime = baseTime.plusSeconds(i); + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(eventTime.toEpochMilli()) + .build(); + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + System.out.println( + "Created event-" + + i + + " with timestamp: " + + eventTime + + " (" + + eventTime.toEpochMilli() + + ")"); + try { + TimeUnit.MILLISECONDS.sleep(5); + } catch (InterruptedException e) { + } + } + + // Get all events + Session allEvents = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + System.out.println("\n=== ALL EVENTS ==="); + for (Event e : allEvents.events()) { + System.out.println( + e.id() + ": " + e.timestamp() + " (" + Instant.ofEpochMilli(e.timestamp()) + ")"); + } + + // Filter after 5 seconds + Instant threshold = baseTime.plusSeconds(5); + System.out.println( + "\n=== FILTERING AFTER: " + threshold + " (" + threshold.toEpochMilli() + ") ==="); + GetSessionConfig config = GetSessionConfig.builder().afterTimestamp(threshold).build(); + Session filtered = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + System.out.println("Expected ~5 events (event-6 through event-10)"); + System.out.println("Actually got: " + filtered.events().size() + " events"); + for (Event e : filtered.events()) { + System.out.println( + " " + e.id() + ": " + e.timestamp() + " (" + Instant.ofEpochMilli(e.timestamp()) + ")"); + } + + assertEquals(5, filtered.events().size(), "Should get exactly 5 events after threshold"); + assertEquals("event-6", filtered.events().get(0).id()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/DialectDetectorTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DialectDetectorTest.java new file mode 100644 index 000000000..a2dbbb24a --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DialectDetectorTest.java @@ -0,0 +1,119 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.adk.sessions.dialect.DialectDetector; +import com.google.adk.sessions.dialect.H2Dialect; +import com.google.adk.sessions.dialect.MySqlDialect; +import com.google.adk.sessions.dialect.PostgresDialect; +import com.google.adk.sessions.dialect.SpannerDialect; +import com.google.adk.sessions.dialect.SqlDialect; +import org.junit.jupiter.api.Test; + +public class DialectDetectorTest { + + @Test + public void testDetectPostgreSQLDialectFromUrl() { + String url = "jdbc:postgresql://localhost:5432/testdb"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(PostgresDialect.class, dialect.getClass()); + assertEquals("PostgreSQL", dialect.dialectName()); + } + + @Test + public void testDetectMySQLDialectFromUrl() { + String url = "jdbc:mysql://localhost:3306/testdb"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(MySqlDialect.class, dialect.getClass()); + assertEquals("MySQL", dialect.dialectName()); + } + + @Test + public void testDetectH2DialectFromUrl() { + String url = "jdbc:h2:mem:testdb"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(H2Dialect.class, dialect.getClass()); + assertEquals("H2", dialect.dialectName()); + } + + @Test + public void testDetectSpannerDialectFromUrl() { + String url = "jdbc:cloudspanner:/projects/test/instances/test/databases/test"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(SpannerDialect.class, dialect.getClass()); + assertEquals("Cloud Spanner", dialect.dialectName()); + } + + @Test + public void testDetectDialectWithParametersInUrl() { + String url = "jdbc:postgresql://localhost:5432/testdb?user=admin&password=secret"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(PostgresDialect.class, dialect.getClass()); + } + + @Test + public void testDetectDialectUnsupportedDatabase() { + String url = "jdbc:oracle:thin:@localhost:1521:testdb"; + assertThrows(IllegalArgumentException.class, () -> DialectDetector.detectFromJdbcUrl(url)); + } + + @Test + public void testDetectH2InMemoryDatabase() { + String url = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(H2Dialect.class, dialect.getClass()); + } + + @Test + public void testDetectH2FileDatabase() { + String url = "jdbc:h2:file:/data/testdb"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(H2Dialect.class, dialect.getClass()); + } + + @Test + public void testDetectPostgreSQLWithSSL() { + String url = + "jdbc:postgresql://localhost:5432/testdb?ssl=true&sslfactory=org.postgresql.ssl.NonValidatingFactory"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(PostgresDialect.class, dialect.getClass()); + } + + @Test + public void testDetectMySQLWithUTF8() { + String url = "jdbc:mysql://localhost:3306/testdb?useUnicode=true&characterEncoding=UTF-8"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(MySqlDialect.class, dialect.getClass()); + } + + @Test + public void testDetectDialectCaseInsensitive() { + String urlUpper = "JDBC:POSTGRESQL://localhost:5432/testdb"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(urlUpper); + assertEquals(PostgresDialect.class, dialect.getClass()); + } + + @Test + public void testDetectSpannerWithComplexUrl() { + String url = + "jdbc:cloudspanner:/projects/my-project/instances/my-instance/databases/my-database?credentials=/path/to/credentials.json"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(SpannerDialect.class, dialect.getClass()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/EventFilteringTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/EventFilteringTest.java new file mode 100644 index 000000000..10f1b2856 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/EventFilteringTest.java @@ -0,0 +1,297 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class EventFilteringTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:filter_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private static final String TEST_APP_NAME = "filter-test-app"; + private static final String TEST_USER_ID = "filter-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testFilterByNumRecentEvents() throws InterruptedException { + String sessionId = "recent-events-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 10; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(3).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(3, session.events().size()); + assertEquals("event-8", session.events().get(0).id()); + assertEquals("event-9", session.events().get(1).id()); + assertEquals("event-10", session.events().get(2).id()); + } + + @Test + public void testFilterByAfterTimestamp() throws InterruptedException { + String sessionId = "timestamp-filter-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant startTime = Instant.now(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(startTime.plusSeconds(i).toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + GetSessionConfig config = + GetSessionConfig.builder().afterTimestamp(startTime.plusSeconds(3)).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(2, session.events().size()); + assertEquals("event-4", session.events().get(0).id()); + assertEquals("event-5", session.events().get(1).id()); + } + + @Test + public void testFilterByNumRecentEventsZero() { + String sessionId = "zero-events-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(0).build(); + Session filteredSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(filteredSession); + assertEquals(0, filteredSession.events().size()); + } + + @Test + public void testNoFilterReturnsAllEvents() throws InterruptedException { + String sessionId = "all-events-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(session); + assertEquals(5, session.events().size()); + } + + @Test + public void testCombinedFilters() throws InterruptedException { + String sessionId = "combined-filter-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant startTime = Instant.now(); + + for (int i = 1; i <= 10; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(startTime.plusSeconds(i).toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + GetSessionConfig config = + GetSessionConfig.builder() + .afterTimestamp(startTime.plusSeconds(3)) + .numRecentEvents(3) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertTrue(session.events().size() <= 3); + } + + @Test + public void testNoFilterReturnsAllEventsLargeDataset() throws InterruptedException { + String sessionId = "large-dataset-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 50; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(session); + assertEquals(50, session.events().size()); + assertEquals("event-1", session.events().get(0).id()); + assertEquals("event-25", session.events().get(24).id()); + assertEquals("event-50", session.events().get(49).id()); + } + + @Test + public void testLimitedEventsFromLargeDatasetReturnsCorrectOrder() throws InterruptedException { + String sessionId = "limited-large-dataset-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 50; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(20).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(20, session.events().size()); + assertEquals("event-31", session.events().get(0).id()); + assertEquals("event-40", session.events().get(9).id()); + assertEquals("event-50", session.events().get(19).id()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/FlywayMigrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/FlywayMigrationTest.java new file mode 100644 index 000000000..3179f6043 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/FlywayMigrationTest.java @@ -0,0 +1,394 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.HashSet; +import java.util.Set; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; + +public class FlywayMigrationTest { + + @Test + public void testFlywayMigrationsApplied() { + String dbUrl = "jdbc:h2:mem:flyway_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + + assertDoesNotThrow( + () -> { + try (DatabaseSessionService service = new DatabaseSessionService(dbUrl)) { + assertNotNull(service); + } + }); + + try (Connection conn = DriverManager.getConnection(dbUrl); + Statement stmt = conn.createStatement()) { + + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM \"flyway_schema_history\""); + rs.next(); + int migrationCount = rs.getInt(1); + assert migrationCount > 0 : "Flyway migrations should be applied"; + + rs = + stmt.executeQuery( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'PUBLIC'"); + boolean hasAppStates = false; + boolean hasUserStates = false; + boolean hasSessions = false; + boolean hasEvents = false; + + while (rs.next()) { + String tableName = rs.getString("table_name"); + if (tableName.equalsIgnoreCase("APP_STATES")) hasAppStates = true; + if (tableName.equalsIgnoreCase("USER_STATES")) hasUserStates = true; + if (tableName.equalsIgnoreCase("SESSIONS")) hasSessions = true; + if (tableName.equalsIgnoreCase("EVENTS")) hasEvents = true; + } + + assert hasAppStates : "app_states table should exist"; + assert hasUserStates : "user_states table should exist"; + assert hasSessions : "sessions table should exist"; + assert hasEvents : "events table should exist"; + + } catch (Exception e) { + throw new RuntimeException("Failed to verify Flyway migrations", e); + } + } + + @Test + public void testMultipleServiceInstancesShareSchema() { + String dbUrl = "jdbc:h2:mem:shared_schema_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + + try (DatabaseSessionService service1 = new DatabaseSessionService(dbUrl); + DatabaseSessionService service2 = new DatabaseSessionService(dbUrl)) { + + assertNotNull(service1); + assertNotNull(service2); + } + } + + @Test + public void testTenConcurrentServiceInstances() { + String dbUrl = "jdbc:h2:mem:ten_instances_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + + // Create 10 service instances concurrently, simulating 10 Kubernetes pods starting + DatabaseSessionService[] services = new DatabaseSessionService[10]; + + try { + for (int i = 0; i < 10; i++) { + services[i] = new DatabaseSessionService(dbUrl); + assertNotNull(services[i], "Service instance " + i + " should be initialized"); + } + + // Verify all instances are operational + for (int i = 0; i < 10; i++) { + assertNotNull(services[i], "Service instance " + i + " should still be valid"); + } + + } finally { + // Clean up all instances + for (int i = 0; i < 10; i++) { + if (services[i] != null) { + services[i].close(); + } + } + } + } + + @Test + public void testFlywayMigrationPostgres() { + String jdbcUrl = + "jdbc:postgresql://localhost:5432/adk_flyway_test?user=adk_test&password=adk_test_password"; + + // Check if PostgreSQL is available + try { + DriverManager.getConnection(jdbcUrl).close(); + } catch (SQLException e) { + Assumptions.assumeTrue(false, "PostgreSQL not available - skipping test"); + return; + } + + // Verify tables DO NOT exist before migration + try (Connection conn = DriverManager.getConnection(jdbcUrl); + Statement stmt = conn.createStatement()) { + ResultSet rs = + stmt.executeQuery( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('sessions', 'events', 'app_states', 'user_states')"); + rs.next(); + int tableCount = rs.getInt(1); + assertTrue(tableCount == 0, "Tables should NOT exist before migration in PostgreSQL"); + } catch (SQLException e) { + throw new RuntimeException("Failed to verify pre-migration state", e); + } + + DatabaseSessionService[] services = new DatabaseSessionService[10]; + + try { + // Create 10 instances + for (int i = 0; i < 10; i++) { + services[i] = new DatabaseSessionService(jdbcUrl); + assertNotNull(services[i], "PostgreSQL instance " + i + " should be initialized"); + } + + // Verify migration was applied exactly once + try (Connection conn = DriverManager.getConnection(jdbcUrl); + Statement stmt = conn.createStatement()) { + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM flyway_schema_history"); + rs.next(); + int migrationCount = rs.getInt(1); + assertTrue( + migrationCount > 0, "At least one migration should be applied to PostgreSQL database"); + } + + // Verify all expected tables exist AFTER migration + try (Connection conn = DriverManager.getConnection(jdbcUrl); + Statement stmt = conn.createStatement()) { + ResultSet rs = + stmt.executeQuery( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('sessions', 'events', 'app_states', 'user_states', 'adk_internal_metadata') ORDER BY table_name"); + + boolean hasSessions = false; + boolean hasEvents = false; + boolean hasAppStates = false; + boolean hasUserStates = false; + boolean hasMetadata = false; + + while (rs.next()) { + String tableName = rs.getString("table_name"); + if (tableName.equals("sessions")) hasSessions = true; + if (tableName.equals("events")) hasEvents = true; + if (tableName.equals("app_states")) hasAppStates = true; + if (tableName.equals("user_states")) hasUserStates = true; + if (tableName.equals("adk_internal_metadata")) hasMetadata = true; + } + + assertTrue(hasSessions, "sessions table should exist after migration"); + assertTrue(hasEvents, "events table should exist after migration"); + assertTrue(hasAppStates, "app_states table should exist after migration"); + assertTrue(hasUserStates, "user_states table should exist after migration"); + assertTrue(hasMetadata, "adk_internal_metadata table should exist after migration"); + } + + // Verify column schema for sessions table + try (Connection conn = DriverManager.getConnection(jdbcUrl)) { + ResultSet rs = conn.getMetaData().getColumns(null, "public", "sessions", null); + Set columnNames = new HashSet<>(); + while (rs.next()) { + columnNames.add(rs.getString("COLUMN_NAME")); + } + assertTrue(columnNames.contains("app_name"), "sessions should have app_name column"); + assertTrue(columnNames.contains("user_id"), "sessions should have user_id column"); + assertTrue(columnNames.contains("id"), "sessions should have id column"); + assertTrue(columnNames.contains("state"), "sessions should have state column"); + assertTrue(columnNames.contains("create_time"), "sessions should have create_time column"); + assertTrue(columnNames.contains("update_time"), "sessions should have update_time column"); + } + + // Verify column schema for events table + try (Connection conn = DriverManager.getConnection(jdbcUrl)) { + ResultSet rs = conn.getMetaData().getColumns(null, "public", "events", null); + Set columnNames = new HashSet<>(); + while (rs.next()) { + columnNames.add(rs.getString("COLUMN_NAME")); + } + assertTrue(columnNames.contains("id"), "events should have id column"); + assertTrue(columnNames.contains("app_name"), "events should have app_name column"); + assertTrue(columnNames.contains("user_id"), "events should have user_id column"); + assertTrue(columnNames.contains("session_id"), "events should have session_id column"); + assertTrue( + columnNames.contains("invocation_id"), "events should have invocation_id column"); + assertTrue(columnNames.contains("timestamp"), "events should have timestamp column"); + assertTrue(columnNames.contains("event_data"), "events should have event_data column"); + } + + // Verify foreign key from events to sessions + try (Connection conn = DriverManager.getConnection(jdbcUrl)) { + ResultSet rs = conn.getMetaData().getImportedKeys(null, "public", "events"); + boolean hasSessionFK = false; + while (rs.next()) { + String pkTable = rs.getString("PKTABLE_NAME"); + String fkTable = rs.getString("FKTABLE_NAME"); + if (pkTable.equals("sessions") && fkTable.equals("events")) { + hasSessionFK = true; + } + } + assertTrue(hasSessionFK, "events should have foreign key to sessions"); + } + + } catch (Exception e) { + throw new RuntimeException("PostgreSQL test failed", e); + } finally { + // Close all services + for (int i = 0; i < 10; i++) { + if (services[i] != null) { + services[i].close(); + } + } + + // Clean up database + try (Connection conn = DriverManager.getConnection(jdbcUrl); + Statement stmt = conn.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS events CASCADE"); + stmt.execute("DROP TABLE IF EXISTS sessions CASCADE"); + stmt.execute("DROP TABLE IF EXISTS user_states CASCADE"); + stmt.execute("DROP TABLE IF EXISTS app_states CASCADE"); + stmt.execute("DROP TABLE IF EXISTS adk_internal_metadata CASCADE"); + stmt.execute("DROP TABLE IF EXISTS flyway_schema_history CASCADE"); + } catch (SQLException e) { + System.err.println("Failed to clean up PostgreSQL test database: " + e.getMessage()); + } + } + } + + @Test + public void testFlywayMigrationMysql() { + String jdbcUrl = + "jdbc:mysql://localhost:3306/adk_flyway_test?user=adk_test&password=adk_test_password"; + + // Check if MySQL is available + try { + DriverManager.getConnection(jdbcUrl).close(); + } catch (SQLException e) { + Assumptions.assumeTrue(false, "MySQL not available - skipping test"); + return; + } + + // Verify tables DO NOT exist before migration + try (Connection conn = DriverManager.getConnection(jdbcUrl); + Statement stmt = conn.createStatement()) { + ResultSet rs = + stmt.executeQuery( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'adk_flyway_test' AND table_name IN ('sessions', 'events', 'app_states', 'user_states')"); + rs.next(); + int tableCount = rs.getInt(1); + assertTrue(tableCount == 0, "Tables should NOT exist before migration in MySQL"); + } catch (SQLException e) { + throw new RuntimeException("Failed to verify pre-migration state", e); + } + + DatabaseSessionService[] services = new DatabaseSessionService[10]; + + try { + // Create 10 instances + for (int i = 0; i < 10; i++) { + services[i] = new DatabaseSessionService(jdbcUrl); + assertNotNull(services[i], "MySQL instance " + i + " should be initialized"); + } + + // Verify migration was applied exactly once + try (Connection conn = DriverManager.getConnection(jdbcUrl); + Statement stmt = conn.createStatement()) { + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM flyway_schema_history"); + rs.next(); + int migrationCount = rs.getInt(1); + assertTrue( + migrationCount > 0, "At least one migration should be applied to MySQL database"); + } + + // Verify all expected tables exist AFTER migration + try (Connection conn = DriverManager.getConnection(jdbcUrl); + Statement stmt = conn.createStatement()) { + ResultSet rs = + stmt.executeQuery( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'adk_flyway_test' AND table_name IN ('sessions', 'events', 'app_states', 'user_states', 'adk_internal_metadata') ORDER BY table_name"); + + boolean hasSessions = false; + boolean hasEvents = false; + boolean hasAppStates = false; + boolean hasUserStates = false; + boolean hasMetadata = false; + + while (rs.next()) { + String tableName = rs.getString("table_name"); + if (tableName.equals("sessions")) hasSessions = true; + if (tableName.equals("events")) hasEvents = true; + if (tableName.equals("app_states")) hasAppStates = true; + if (tableName.equals("user_states")) hasUserStates = true; + if (tableName.equals("adk_internal_metadata")) hasMetadata = true; + } + + assertTrue(hasSessions, "sessions table should exist after migration"); + assertTrue(hasEvents, "events table should exist after migration"); + assertTrue(hasAppStates, "app_states table should exist after migration"); + assertTrue(hasUserStates, "user_states table should exist after migration"); + assertTrue(hasMetadata, "adk_internal_metadata table should exist after migration"); + } + + // Verify column schema for sessions table + try (Connection conn = DriverManager.getConnection(jdbcUrl)) { + ResultSet rs = conn.getMetaData().getColumns(null, "adk_flyway_test", "sessions", null); + Set columnNames = new HashSet<>(); + while (rs.next()) { + columnNames.add(rs.getString("COLUMN_NAME")); + } + assertTrue(columnNames.contains("app_name"), "sessions should have app_name column"); + assertTrue(columnNames.contains("user_id"), "sessions should have user_id column"); + assertTrue(columnNames.contains("id"), "sessions should have id column"); + assertTrue(columnNames.contains("state"), "sessions should have state column"); + assertTrue(columnNames.contains("create_time"), "sessions should have create_time column"); + assertTrue(columnNames.contains("update_time"), "sessions should have update_time column"); + } + + // Verify column schema for events table + try (Connection conn = DriverManager.getConnection(jdbcUrl)) { + ResultSet rs = conn.getMetaData().getColumns(null, "adk_flyway_test", "events", null); + Set columnNames = new HashSet<>(); + while (rs.next()) { + columnNames.add(rs.getString("COLUMN_NAME")); + } + assertTrue(columnNames.contains("id"), "events should have id column"); + assertTrue(columnNames.contains("app_name"), "events should have app_name column"); + assertTrue(columnNames.contains("user_id"), "events should have user_id column"); + assertTrue(columnNames.contains("session_id"), "events should have session_id column"); + assertTrue( + columnNames.contains("invocation_id"), "events should have invocation_id column"); + assertTrue(columnNames.contains("timestamp"), "events should have timestamp column"); + assertTrue(columnNames.contains("event_data"), "events should have event_data column"); + } + + // Verify foreign key from events to sessions + try (Connection conn = DriverManager.getConnection(jdbcUrl); + Statement stmt = conn.createStatement()) { + ResultSet rs = + stmt.executeQuery( + "SELECT COUNT(*) FROM information_schema.KEY_COLUMN_USAGE " + + "WHERE TABLE_SCHEMA = 'adk_flyway_test' " + + "AND TABLE_NAME = 'events' " + + "AND REFERENCED_TABLE_NAME = 'sessions'"); + rs.next(); + assertTrue(rs.getInt(1) > 0, "events should have foreign key to sessions"); + } + + } catch (Exception e) { + throw new RuntimeException("MySQL test failed", e); + } finally { + // Close all services + for (int i = 0; i < 10; i++) { + if (services[i] != null) { + services[i].close(); + } + } + + // Clean up database + try (Connection conn = DriverManager.getConnection(jdbcUrl); + Statement stmt = conn.createStatement()) { + stmt.execute("SET FOREIGN_KEY_CHECKS = 0"); + stmt.execute("DROP TABLE IF EXISTS events"); + stmt.execute("DROP TABLE IF EXISTS sessions"); + stmt.execute("DROP TABLE IF EXISTS user_states"); + stmt.execute("DROP TABLE IF EXISTS app_states"); + stmt.execute("DROP TABLE IF EXISTS adk_internal_metadata"); + stmt.execute("DROP TABLE IF EXISTS flyway_schema_history"); + stmt.execute("SET FOREIGN_KEY_CHECKS = 1"); + } catch (SQLException e) { + System.err.println("Failed to clean up MySQL test database: " + e.getMessage()); + } + } + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/HikariConfigTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/HikariConfigTest.java new file mode 100644 index 000000000..a5111a5f7 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/HikariConfigTest.java @@ -0,0 +1,216 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.zaxxer.hikari.HikariDataSource; +import java.lang.reflect.Field; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +public class HikariConfigTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:hikari_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private DatabaseSessionService sessionService; + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testDefaultHikariConfig() throws Exception { + sessionService = new DatabaseSessionService(TEST_DB_URL); + + HikariDataSource dataSource = getDataSource(sessionService); + assertNotNull(dataSource); + assertEquals(10, dataSource.getMaximumPoolSize()); + assertEquals(2, dataSource.getMinimumIdle()); + assertEquals(30000, dataSource.getConnectionTimeout()); + assertEquals(600000, dataSource.getIdleTimeout()); + assertEquals(1800000, dataSource.getMaxLifetime()); + } + + @Test + public void testCustomMaximumPoolSize() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", 20); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(20, dataSource.getMaximumPoolSize()); + assertEquals(2, dataSource.getMinimumIdle()); + } + + @Test + public void testCustomMinimumIdle() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.minimumIdle", 5); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(10, dataSource.getMaximumPoolSize()); + assertEquals(5, dataSource.getMinimumIdle()); + } + + @Test + public void testCustomConnectionTimeout() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.connectionTimeout", 60000L); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(60000, dataSource.getConnectionTimeout()); + } + + @Test + public void testCustomIdleTimeout() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.idleTimeout", 300000L); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(300000, dataSource.getIdleTimeout()); + } + + @Test + public void testCustomMaxLifetime() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maxLifetime", 900000L); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(900000, dataSource.getMaxLifetime()); + } + + @Test + public void testAllCustomHikariProperties() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", 25); + properties.put("hikari.minimumIdle", 10); + properties.put("hikari.connectionTimeout", 45000L); + properties.put("hikari.idleTimeout", 400000L); + properties.put("hikari.maxLifetime", 1200000L); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(25, dataSource.getMaximumPoolSize()); + assertEquals(10, dataSource.getMinimumIdle()); + assertEquals(45000, dataSource.getConnectionTimeout()); + assertEquals(400000, dataSource.getIdleTimeout()); + assertEquals(1200000, dataSource.getMaxLifetime()); + } + + @Test + public void testNonHikariPropertiesArePassedToDataSource() throws Exception { + Map properties = new HashMap<>(); + properties.put("cachePrepStmts", "true"); + properties.put("prepStmtCacheSize", 250); + properties.put("hikari.maximumPoolSize", 15); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(15, dataSource.getMaximumPoolSize()); + } + + @Test + public void testInvalidIntegerPropertyUsesDefault() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", "invalid"); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(10, dataSource.getMaximumPoolSize()); + } + + @Test + public void testInvalidLongPropertyUsesDefault() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.connectionTimeout", "invalid"); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(30000, dataSource.getConnectionTimeout()); + } + + @Test + public void testIntegerAsNumberType() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", Integer.valueOf(30)); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(30, dataSource.getMaximumPoolSize()); + } + + @Test + public void testLongAsNumberType() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.connectionTimeout", Long.valueOf(50000L)); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(50000, dataSource.getConnectionTimeout()); + } + + @Test + public void testStringNumberConversion() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", "35"); + properties.put("hikari.connectionTimeout", "40000"); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(35, dataSource.getMaximumPoolSize()); + assertEquals(40000, dataSource.getConnectionTimeout()); + } + + @Test + public void testNullPropertyUsesDefault() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", null); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(10, dataSource.getMaximumPoolSize()); + } + + private HikariDataSource getDataSource(DatabaseSessionService service) throws Exception { + Field dataSourceField = DatabaseSessionService.class.getDeclaredField("dataSource"); + dataSourceField.setAccessible(true); + return (HikariDataSource) dataSourceField.get(service); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/ListSessionsEventsTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ListSessionsEventsTest.java new file mode 100644 index 000000000..9973d6d31 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ListSessionsEventsTest.java @@ -0,0 +1,89 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ListSessionsEventsTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:list_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private static final String TEST_APP_NAME = "list-test-app"; + private static final String TEST_USER_ID = "list-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testListSessionsReturnsAllSessions() { + String userId = "list-sessions-user"; + for (int i = 1; i <= 50; i++) { + sessionService + .createSession(TEST_APP_NAME, userId, new ConcurrentHashMap<>(), "session-" + i) + .blockingGet(); + } + + ListSessionsResponse response = + sessionService.listSessions(TEST_APP_NAME, userId).blockingGet(); + + assertNotNull(response); + assertEquals(50, response.sessions().size()); + } + + @Test + public void testListEventsReturnsAllEvents() throws InterruptedException { + String userId = "list-events-user"; + String sessionId = "all-events-test"; + sessionService + .createSession(TEST_APP_NAME, userId, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 50; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, userId, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + + ListEventsResponse response = + sessionService.listEvents(TEST_APP_NAME, userId, sessionId).blockingGet(); + + assertNotNull(response); + assertEquals(50, response.events().size()); + assertEquals("event-1", response.events().get(0).id()); + assertEquals("event-25", response.events().get(24).id()); + assertEquals("event-50", response.events().get(49).id()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLAgentIntegrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLAgentIntegrationTest.java new file mode 100644 index 000000000..f8afa272b --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLAgentIntegrationTest.java @@ -0,0 +1,355 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.SequentialAgent; +import com.google.adk.events.Event; +import com.google.adk.models.LlmResponse; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.adk.testing.TestLlm; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +/** + * Integration tests for Agents using DatabaseSessionService with real MySQL 8.0 database. + * + *

This test suite verifies that agents work correctly with MySQL-backed session persistence, + * including: - Sequential agent execution with database persistence - State propagation between + * agents via outputKey - Event storage and retrieval - App/user/session state management with + * database backend + * + *

Prerequisites: Start MySQL test database with: + * + *

{@code
+ * docker-compose -f scripts/docker-compose.test.yml up -d mysql-test
+ * }
+ * + *

Configuration: - Host: localhost:3307 - Database: adk_test - User: adk_user - Password: + * adk_password + */ +@Tag("integration") +public class MySQLAgentIntegrationTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.MYSQL_JDBC_URL; + private static final String TEST_APP_NAME = "mysql-agent-integration-test"; + private static final String TEST_USER_ID = "agent-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isMySQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("MySQL")); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testSequentialAgentWithDatabasePersistence() { + Content agentAResponse = Content.fromParts(Part.fromText("The topic is: AI")); + TestLlm llmA = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(agentAResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentA = + LlmAgent.builder() + .name("AgentA") + .model(llmA) + .instruction("Extract topic") + .outputKey("topic") + .build(); + + Content agentBResponse = Content.fromParts(Part.fromText("Summary: AI is important")); + TestLlm llmB = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(agentBResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentB = + LlmAgent.builder() + .name("AgentB") + .model(llmB) + .instruction("Summarize topic: ${topic}") + .outputKey("summary") + .build(); + + SequentialAgent sequential = + SequentialAgent.builder() + .name("SequentialAgent") + .subAgents(ImmutableList.of(agentA, agentB)) + .build(); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx = + InvocationContext.builder() + .sessionService(sessionService) + .session(session) + .agent(sequential) + .build(); + + List events = + sequential + .runAsync(ctx) + .flatMap(event -> ctx.sessionService().appendEvent(ctx.session(), event).toFlowable()) + .toList() + .blockingGet(); + + assertNotNull(events); + assertTrue(events.size() >= 2, "Expected at least 2 events from sequential agents"); + + assertEquals("The topic is: AI", ctx.session().state().get("topic")); + assertEquals("Summary: AI is important", ctx.session().state().get("summary")); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, session.id(), Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals(session.id(), retrievedSession.id()); + assertEquals("The topic is: AI", retrievedSession.state().get("topic")); + assertEquals("Summary: AI is important", retrievedSession.state().get("summary")); + assertTrue( + retrievedSession.events().size() >= 2, "Expected at least 2 events persisted in database"); + } + + @Test + public void testAgentWithAppAndUserStatePersistence() { + Content configResponse = Content.fromParts(Part.fromText("{\"version\": \"1.0\"}")); + TestLlm llmA = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(configResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentA = + LlmAgent.builder() + .name("ConfigAgent") + .model(llmA) + .instruction("Return config") + .outputKey("app:config") + .build(); + + Content prefResponse = Content.fromParts(Part.fromText("dark")); + TestLlm llmB = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(prefResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentB = + LlmAgent.builder() + .name("PreferenceAgent") + .model(llmB) + .instruction("Use config: ${app:config}") + .outputKey("user:theme") + .build(); + + SequentialAgent sequential = + SequentialAgent.builder() + .name("Sequential") + .subAgents(ImmutableList.of(agentA, agentB)) + .build(); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx = + InvocationContext.builder() + .sessionService(sessionService) + .session(session) + .agent(sequential) + .build(); + + sequential + .runAsync(ctx) + .flatMap(event -> ctx.sessionService().appendEvent(ctx.session(), event).toFlowable()) + .toList() + .blockingGet(); + + assertEquals("{\"version\": \"1.0\"}", ctx.session().state().get("app:config")); + assertEquals("dark", ctx.session().state().get("user:theme")); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, session.id(), Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals("{\"version\": \"1.0\"}", retrievedSession.state().get("app:config")); + assertEquals("dark", retrievedSession.state().get("user:theme")); + } + + @Test + public void testAgentStatePersistedAcrossSessions() { + Content response1 = Content.fromParts(Part.fromText("User preference stored")); + TestLlm llm1 = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(response1) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agent1 = + LlmAgent.builder() + .name("PreferenceAgent") + .model(llm1) + .instruction("Store preference") + .outputKey("user:language") + .build(); + + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx1 = + InvocationContext.builder() + .sessionService(sessionService) + .session(session1) + .agent(agent1) + .build(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("user:language", "English"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author(agent1.name()) + .content(response1) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertEquals("English", session2.state().get("user:language")); + } + + @Test + public void testRegularStateIsolatedBetweenSessions() { + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("session_data", "session1_value"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author("test-agent") + .content(Content.fromParts(Part.fromText("Session 1 data"))) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + assertEquals("session1_value", session1.state().get("session_data")); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertNull( + session2.state().get("session_data"), "Regular state should not persist across sessions"); + } + + @Test + public void testAppStatePersistedAcrossSessions() { + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("app:api_key", "key-12345"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author("test-agent") + .content(Content.fromParts(Part.fromText("App config stored"))) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertEquals("key-12345", session2.state().get("app:api_key")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLFunctionalTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLFunctionalTest.java new file mode 100644 index 000000000..faf6fc22b --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLFunctionalTest.java @@ -0,0 +1,184 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +@Tag("integration") +public class MySQLFunctionalTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.MYSQL_JDBC_URL; + private String TEST_APP_NAME; + private String TEST_USER_ID; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isMySQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("MySQL")); + + TEST_APP_NAME = "jdbc-mysql-test-app-" + System.currentTimeMillis(); + TEST_USER_ID = "jdbc-mysql-test-user-" + System.currentTimeMillis(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testBasicSessionOperations() { + String sessionId = "mysql-basic-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key", "value"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals(sessionId, session.id()); + assertEquals("value", session.state().get("key")); + } + + @Test + public void testEventActionsWithStateDelta() { + String sessionId = "mysql-actions-test-" + System.currentTimeMillis(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("count", 1); + stateDelta.put("app:shared", "global"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals(1, retrieved.state().get("count")); + assertEquals("global", retrieved.state().get("app:shared")); + } + + @Test + public void testJSONStorageAndRetrieval() { + String sessionId = "mysql-json-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("nested", java.util.Map.of("inner", "value")); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertNotNull(retrieved.state().get("nested")); + } + + @Test + public void testUpsertAppState() { + String sessionId1 = "mysql-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "mysql-upsert-2-" + System.currentTimeMillis(); + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("app:config", "value1"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("app:config", "value2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId1, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("value2", retrieved.state().get("app:config")); + } + + @Test + public void testGetSessionWithInvalidConfig() { + String sessionId = "mysql-invalid-config-test-" + System.currentTimeMillis(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Add 5 events + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + // Test negative numRecentEvents: -1 should be treated as abs(-1) = 1 (last 1 event) + GetSessionConfig negativeNumEvents = GetSessionConfig.builder().numRecentEvents(-1).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(negativeNumEvents)) + .blockingGet(); + + assertNotNull(session); + // Should return exactly 1 event (the most recent one) + assertEquals( + 1, + session.events().size(), + "Expected 1 event for numRecentEvents=-1, got " + session.events().size()); + // Should be the last event added (event-5) + assertEquals( + "event-5", + session.events().get(0).id(), + "Expected most recent event (event-5), got " + session.events().get(0).id()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLIntegrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLIntegrationTest.java new file mode 100644 index 000000000..43a5acb9b --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLIntegrationTest.java @@ -0,0 +1,809 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +/** + * Integration tests that verify MySQL database table operations directly. + * + *

This test suite validates: - sessions table: CRUD operations, timestamps, state storage - + * events table: event persistence, foreign key relationships, cascading deletes - app_states table: + * application-wide state management - user_states table: user-specific state management - JSON + * storage and retrieval - Timestamp tracking (create_time, update_time) + * + *

Prerequisites: Start MySQL test database with: + * + *

{@code
+ * docker run -d -p 3306:3306 \
+ *   -e MYSQL_DATABASE=adk_test \
+ *   -e MYSQL_USER=adk_user \
+ *   -e MYSQL_PASSWORD=adk_password \
+ *   -e MYSQL_ROOT_PASSWORD=root_password \
+ *   mysql:8.0
+ * }
+ */ +@Tag("integration") +public class MySQLIntegrationTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.MYSQL_JDBC_URL; + private static final String TEST_APP_NAME = "mysql-table-test-app"; + private static final String TEST_USER_ID = "mysql-table-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isMySQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("MySQL")); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + // ==================== SESSIONS TABLE TESTS ==================== + + @Test + public void testSessionsTableCreation() throws SQLException { + String sessionId = "session-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("test_key", "test_value"); + + // Create session via service + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify in database directly + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT app_name, user_id, id, state, create_time, update_time " + + "FROM sessions WHERE app_name = ? AND user_id = ? AND id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + stmt.setString(3, sessionId); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Session should exist in sessions table"); + + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + assertEquals(TEST_USER_ID, rs.getString("user_id")); + assertEquals(sessionId, rs.getString("id")); + + // Verify JSON state + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertTrue(stateJson.contains("test_key")); + assertTrue(stateJson.contains("test_value")); + + // Verify timestamps + Timestamp createTime = rs.getTimestamp("create_time"); + Timestamp updateTime = rs.getTimestamp("update_time"); + assertNotNull(createTime, "create_time should not be null"); + assertNotNull(updateTime, "update_time should not be null"); + + assertFalse(rs.next(), "Should only have one session with this ID"); + } + } + } + } + + @Test + public void testSessionsTableUpdate() throws SQLException { + String sessionId = "session-update-" + System.currentTimeMillis(); + + // Create initial session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Timestamp initialUpdateTime; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT update_time FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + initialUpdateTime = rs.getTimestamp("update_time"); + } + } + } + + // Wait a bit to ensure timestamp difference + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Update session by appending event + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("updated_key", "updated_value"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Update event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify update_time changed + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT state, update_time FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + + String stateJson = rs.getString("state"); + assertTrue(stateJson.contains("updated_key")); + assertTrue(stateJson.contains("updated_value")); + + Timestamp newUpdateTime = rs.getTimestamp("update_time"); + assertTrue( + newUpdateTime.after(initialUpdateTime), + "update_time should be updated after state change"); + } + } + } + } + + @Test + public void testSessionsTablePrimaryKey() throws SQLException { + String sessionId = "session-pk-" + System.currentTimeMillis(); + + // Create session + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Verify primary key constraint (app_name, user_id, id) + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + // Count sessions with this ID + String query = "SELECT COUNT(*) FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Should have exactly one session with this ID"); + } + } + } + } + + // ==================== EVENTS TABLE TESTS ==================== + + @Test + public void testEventsTableCreation() throws SQLException { + String sessionId = "session-events-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Create event + String eventId = UUID.randomUUID().toString(); + String invocationId = "inv-" + UUID.randomUUID().toString(); + long timestamp = Instant.now().toEpochMilli(); + + Event event = + Event.builder() + .id(eventId) + .invocationId(invocationId) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event content"))) + .timestamp(timestamp) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify in database + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT id, app_name, user_id, session_id, invocation_id, timestamp, event_data " + + "FROM events WHERE id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, eventId); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Event should exist in events table"); + + assertEquals(eventId, rs.getString("id")); + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + assertEquals(TEST_USER_ID, rs.getString("user_id")); + assertEquals(sessionId, rs.getString("session_id")); + assertEquals(invocationId, rs.getString("invocation_id")); + + Timestamp eventTimestamp = rs.getTimestamp("timestamp"); + assertNotNull(eventTimestamp); + + // Verify JSON event_data + String eventData = rs.getString("event_data"); + assertNotNull(eventData); + assertTrue(eventData.contains("Test event content")); + assertTrue(eventData.contains("test-author")); + + assertFalse(rs.next(), "Should only have one event with this ID"); + } + } + } + } + + @Test + public void testEventsTableMultipleEvents() throws SQLException { + String sessionId = "session-multi-events-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Create multiple events + int numEvents = 5; + for (int i = 0; i < numEvents; i++) { + Event event = + Event.builder() + .id("event-" + i + "-" + UUID.randomUUID()) + .author("author-" + i) + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + // Verify event count in database + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT COUNT(*) FROM events WHERE app_name = ? AND user_id = ? AND session_id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + stmt.setString(3, sessionId); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(numEvents, rs.getInt(1), "Should have " + numEvents + " events"); + } + } + } + } + + @Test + public void testEventsTableForeignKeyConstraint() throws SQLException { + String sessionId = "session-fk-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Add event + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("FK test event"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify events exist + int eventCountBefore; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM events WHERE session_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + eventCountBefore = rs.getInt(1); + assertTrue(eventCountBefore > 0, "Should have at least one event"); + } + } + } + + // Delete session (should cascade delete events) + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String deleteQuery = "DELETE FROM sessions WHERE app_name = ? AND user_id = ? AND id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(deleteQuery)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + stmt.setString(3, sessionId); + stmt.executeUpdate(); + } + } + + // Verify events were cascade deleted + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM events WHERE session_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(0, rs.getInt(1), "Events should be cascade deleted with session"); + } + } + } + } + + @Test + public void testEventsTableTimestampOrdering() throws SQLException { + String sessionId = "session-timestamp-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Create events with specific timestamps + long baseTime = Instant.now().toEpochMilli(); + for (int i = 0; i < 3; i++) { + Event event = + Event.builder() + .id("event-" + i + "-" + UUID.randomUUID()) + .author("author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(baseTime + (i * 1000)) + .build(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + // Verify events are ordered by timestamp + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + // MySQL version: Use ->>'$.id' syntax instead of ->>'id' + String query = + "SELECT event_data->>'$.id' as event_id, timestamp " + + "FROM events WHERE session_id = ? ORDER BY timestamp ASC"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + + try (ResultSet rs = stmt.executeQuery()) { + Timestamp previousTimestamp = null; + int count = 0; + + while (rs.next()) { + Timestamp currentTimestamp = rs.getTimestamp("timestamp"); + assertNotNull(currentTimestamp); + + if (previousTimestamp != null) { + assertTrue( + currentTimestamp.after(previousTimestamp) + || currentTimestamp.equals(previousTimestamp), + "Events should be ordered by timestamp"); + } + + previousTimestamp = currentTimestamp; + count++; + } + + assertEquals(3, count, "Should have 3 events"); + } + } + } + } + + // ==================== APP_STATES TABLE TESTS ==================== + + @Test + public void testAppStatesTableCreation() throws SQLException { + String sessionId = "session-app-state-" + System.currentTimeMillis(); + + // Create session with app state + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("app:api_endpoint", "https://api.example.com"); + state.put("app:version", "2.0"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify in app_states table + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT app_name, state, update_time FROM app_states WHERE app_name = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "App state should exist in app_states table"); + + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertTrue(stateJson.contains("api_endpoint")); + assertTrue(stateJson.contains("https://api.example.com")); + assertTrue(stateJson.contains("version")); + assertTrue(stateJson.contains("2.0")); + + Timestamp updateTime = rs.getTimestamp("update_time"); + assertNotNull(updateTime, "update_time should not be null"); + } + } + } + } + + @Test + public void testAppStatesTableUpsert() throws SQLException { + String sessionId1 = "session-app-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "session-app-upsert-2-" + System.currentTimeMillis(); + + // Create first session with app state + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("app:config", "version1"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + Timestamp firstUpdateTime; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT update_time FROM app_states WHERE app_name = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + firstUpdateTime = rs.getTimestamp("update_time"); + } + } + } + + // Wait to ensure timestamp difference + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Create second session with updated app state + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("app:config", "version2"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + // Verify app state was updated (not duplicated) + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT COUNT(*), state, update_time FROM app_states WHERE app_name = ? GROUP BY state, update_time"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Should have app state"); + assertEquals(1, rs.getInt(1), "Should have only one row for this app"); + + String stateJson = rs.getString("state"); + assertTrue(stateJson.contains("version2"), "State should be updated to version2"); + + Timestamp newUpdateTime = rs.getTimestamp("update_time"); + assertTrue( + newUpdateTime.after(firstUpdateTime) || newUpdateTime.equals(firstUpdateTime), + "update_time should be updated"); + + assertFalse(rs.next(), "Should only have one app state row"); + } + } + } + } + + @Test + public void testAppStatesTablePrimaryKey() throws SQLException { + String sessionId = "session-app-pk-" + System.currentTimeMillis(); + + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("app:key", "value"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify only one row per app_name + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM app_states WHERE app_name = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Should have exactly one row per app_name"); + } + } + } + } + + // ==================== USER_STATES TABLE TESTS ==================== + + @Test + public void testUserStatesTableCreation() throws SQLException { + String sessionId = "session-user-state-" + System.currentTimeMillis(); + + // Create session with user state + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("user:language", "English"); + state.put("user:timezone", "UTC"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify in user_states table + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT app_name, user_id, state, update_time FROM user_states WHERE app_name = ? AND user_id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "User state should exist in user_states table"); + + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + assertEquals(TEST_USER_ID, rs.getString("user_id")); + + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertTrue(stateJson.contains("language")); + assertTrue(stateJson.contains("English")); + assertTrue(stateJson.contains("timezone")); + assertTrue(stateJson.contains("UTC")); + + Timestamp updateTime = rs.getTimestamp("update_time"); + assertNotNull(updateTime, "update_time should not be null"); + } + } + } + } + + @Test + public void testUserStatesTableUpsert() throws SQLException { + String sessionId1 = "session-user-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "session-user-upsert-2-" + System.currentTimeMillis(); + + // Create first session with user state + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("user:preference", "dark"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + Timestamp firstUpdateTime; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT update_time FROM user_states WHERE app_name = ? AND user_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + firstUpdateTime = rs.getTimestamp("update_time"); + } + } + } + + // Wait to ensure timestamp difference + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Create second session with updated user state + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("user:preference", "light"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + // Verify user state was updated (not duplicated) + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT COUNT(*), state, update_time FROM user_states " + + "WHERE app_name = ? AND user_id = ? GROUP BY state, update_time"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Should have user state"); + assertEquals(1, rs.getInt(1), "Should have only one row for this user"); + + String stateJson = rs.getString("state"); + assertTrue(stateJson.contains("light"), "State should be updated to 'light'"); + + Timestamp newUpdateTime = rs.getTimestamp("update_time"); + assertTrue( + newUpdateTime.after(firstUpdateTime) || newUpdateTime.equals(firstUpdateTime), + "update_time should be updated"); + + assertFalse(rs.next(), "Should only have one user state row"); + } + } + } + } + + @Test + public void testUserStatesTablePrimaryKey() throws SQLException { + String sessionId = "session-user-pk-" + System.currentTimeMillis(); + + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("user:key", "value"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify only one row per (app_name, user_id) combination + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM user_states WHERE app_name = ? AND user_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Should have exactly one row per (app_name, user_id)"); + } + } + } + } + + @Test + public void testUserStatesTableIsolationBetweenUsers() throws SQLException { + String user1 = "user-1-" + System.currentTimeMillis(); + String user2 = "user-2-" + System.currentTimeMillis(); + + // Create sessions for two different users + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("user:language", "French"); + sessionService.createSession(TEST_APP_NAME, user1, state1, "session-1").blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("user:language", "Spanish"); + sessionService.createSession(TEST_APP_NAME, user2, state2, "session-2").blockingGet(); + + // Verify both users have separate state + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT user_id, state FROM user_states WHERE app_name = ? AND user_id IN (?, ?)"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, user1); + stmt.setString(3, user2); + try (ResultSet rs = stmt.executeQuery()) { + int count = 0; + while (rs.next()) { + String userId = rs.getString("user_id"); + String stateJson = rs.getString("state"); + + if (userId.equals(user1)) { + assertTrue(stateJson.contains("French"), "User 1 should have French"); + } else if (userId.equals(user2)) { + assertTrue(stateJson.contains("Spanish"), "User 2 should have Spanish"); + } + count++; + } + assertEquals(2, count, "Should have state for both users"); + } + } + } + } + + // ==================== CROSS-TABLE INTEGRATION TESTS ==================== + + @Test + public void testAllTablesIntegration() throws SQLException { + String sessionId = "session-integration-" + System.currentTimeMillis(); + + // Create session with all state types + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("session_key", "session_value"); // session state + state.put("app:api_key", "app-12345"); // app state + state.put("user:theme", "dark"); // user state + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Add event + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("integration-test") + .content(Content.fromParts(Part.fromText("Integration test event"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify all tables have data + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + // Check sessions table + String sessionQuery = "SELECT COUNT(*) FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(sessionQuery)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Session should exist"); + } + } + + // Check events table + String eventQuery = "SELECT COUNT(*) FROM events WHERE session_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(eventQuery)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertTrue(rs.getInt(1) > 0, "Should have at least one event"); + } + } + + // Check app_states table + String appQuery = "SELECT state FROM app_states WHERE app_name = ?"; + try (PreparedStatement stmt = conn.prepareStatement(appQuery)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "App state should exist"); + String appState = rs.getString("state"); + assertTrue(appState.contains("api_key"), "App state should contain api_key"); + } + } + + // Check user_states table + String userQuery = "SELECT state FROM user_states WHERE app_name = ? AND user_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(userQuery)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "User state should exist"); + String userState = rs.getString("state"); + assertTrue(userState.contains("theme"), "User state should contain theme"); + } + } + } + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/NegativeTestCases.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/NegativeTestCases.java new file mode 100644 index 000000000..6a1cd6c7e --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/NegativeTestCases.java @@ -0,0 +1,410 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class NegativeTestCases { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:negative_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "negative-test-app"; + private static final String TEST_USER_ID = "negative-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testCreateSessionWithNullAppName() { + assertThrows( + NullPointerException.class, + () -> + sessionService + .createSession(null, TEST_USER_ID, new ConcurrentHashMap<>(), "session-1") + .blockingGet()); + } + + @Test + public void testCreateSessionWithNullUserId() { + assertThrows( + NullPointerException.class, + () -> + sessionService + .createSession(TEST_APP_NAME, null, new ConcurrentHashMap<>(), "session-1") + .blockingGet()); + } + + @Test + public void testCreateSessionWithNullState() { + // Null state should be accepted and treated as empty state + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, null, "session-1").blockingGet(); + assertNotNull(session); + assertNotNull(session.state()); + assertTrue(session.state().isEmpty()); + } + + @Test + public void testGetSessionWithNullAppName() { + assertThrows( + NullPointerException.class, + () -> + sessionService + .getSession(null, TEST_USER_ID, "session-1", Optional.empty()) + .blockingGet()); + } + + @Test + public void testGetSessionWithNullUserId() { + assertThrows( + NullPointerException.class, + () -> + sessionService + .getSession(TEST_APP_NAME, null, "session-1", Optional.empty()) + .blockingGet()); + } + + @Test + public void testGetSessionWithNullSessionId() { + assertThrows( + NullPointerException.class, + () -> + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, null, Optional.empty()) + .blockingGet()); + } + + @Test + public void testAppendEventToDeletedSession() { + String sessionId = "deleted-session"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + sessionService.deleteSession(TEST_APP_NAME, TEST_USER_ID, sessionId).blockingAwait(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + assertThrows( + SessionNotFoundException.class, + () -> + sessionService + .appendEvent( + Session.builder(sessionId) + .appName(TEST_APP_NAME) + .userId(TEST_USER_ID) + .state(new ConcurrentHashMap<>()) + .events(new ArrayList<>()) + .build(), + event) + .blockingGet()); + } + + @Test + public void testDeleteNonExistentSession() { + sessionService.deleteSession(TEST_APP_NAME, TEST_USER_ID, "non-existent").blockingAwait(); + } + + @Test + public void testGetNonExistentSession() { + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "non-existent", Optional.empty()) + .blockingGet(); + assertNull(session); + } + + @Test + public void testCreateSessionWithEmptySessionId() { + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), "") + .blockingGet(); + assertNotNull(session); + } + + @Test + public void testCreateSessionWithVeryLongSessionId() { + String longId = "a".repeat(200); + + Exception exception = + assertThrows( + Exception.class, + () -> + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), longId) + .blockingGet()); + + assertTrue( + exception.getMessage().contains("too long") || exception.getCause() != null, + "Should fail with constraint violation for long session ID"); + } + + @Test + public void testCreateDuplicateSession() { + String sessionId = "duplicate-session"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + assertThrows( + Exception.class, + () -> + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet()); + } + + @Test + public void testAppendEventWithDuplicateId() { + String sessionId = "duplicate-event-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + String eventId = "duplicate-event-id"; + Event event1 = + Event.builder() + .id(eventId) + .author("author-1") + .content(Content.fromParts(Part.fromText("Event 1"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event1).blockingGet(); + + Event event2 = + Event.builder() + .id(eventId) + .author("author-2") + .content(Content.fromParts(Part.fromText("Event 2"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + assertThrows( + Exception.class, + () -> + sessionService + .appendEvent( + Session.builder(sessionId) + .appName(TEST_APP_NAME) + .userId(TEST_USER_ID) + .state(new ConcurrentHashMap<>()) + .events(new ArrayList<>()) + .build(), + event2) + .blockingGet()); + } + + @Test + public void testStateDeltaWithComplexNestedStructures() { + String sessionId = "complex-state-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap complexDelta = new ConcurrentHashMap<>(); + complexDelta.put("level1", Map.of("level2", Map.of("level3", "deep-value"))); + complexDelta.put("array", java.util.List.of(1, 2, 3, 4, 5)); + complexDelta.put("mixed", Map.of("num", 42, "str", "text", "bool", true)); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Complex state"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(com.google.adk.events.EventActions.builder().stateDelta(complexDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session updatedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(updatedSession); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(retrieved); + assertTrue(retrieved.state().containsKey("level1")); + assertTrue(retrieved.state().containsKey("array")); + assertTrue(retrieved.state().containsKey("mixed")); + } + + @Test + public void testGetSessionWithInvalidConfig() { + String sessionId = "invalid-config-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + // Test negative numRecentEvents: -1 should be treated as abs(-1) = 1 (last 1 event) + GetSessionConfig negativeNumEvents = GetSessionConfig.builder().numRecentEvents(-1).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(negativeNumEvents)) + .blockingGet(); + + assertNotNull(session); + // Should return exactly 1 event (the most recent one) + assertTrue( + session.events().size() == 1, + "Expected 1 event for numRecentEvents=-1, got " + session.events().size()); + // Should be the last event added (event-5) + assertTrue( + session.events().get(0).id().equals("event-5"), + "Expected most recent event (event-5), got " + session.events().get(0).id()); + } + + @Test + public void testConcurrentDeleteAndRead() throws InterruptedException { + String sessionId = "concurrent-delete-read-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + Thread deleter = + new Thread( + () -> { + try { + Thread.sleep(50); + sessionService + .deleteSession(TEST_APP_NAME, TEST_USER_ID, sessionId) + .blockingAwait(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + Thread reader = + new Thread( + () -> { + for (int i = 0; i < 10; i++) { + try { + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + Thread.sleep(20); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } catch (Exception e) { + } + } + }); + + deleter.start(); + reader.start(); + + deleter.join(); + reader.join(); + + Session finalCheck = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNull(finalCheck, "Session should be deleted"); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/PessimisticLockingTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PessimisticLockingTest.java new file mode 100644 index 000000000..f19461d2e --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PessimisticLockingTest.java @@ -0,0 +1,340 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class PessimisticLockingTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:locking_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "locking-test-app"; + private static final String TEST_USER_ID = "locking-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testSerializedEventAppends() throws InterruptedException { + String sessionId = "serialized-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int threadCount = 20; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + AtomicInteger successCount = new AtomicInteger(0); + + for (int i = 0; i < threadCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + eventNum) + .author("thread-" + eventNum) + .content(Content.fromParts(Part.fromText("Message " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + successCount.incrementAndGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS)); + executor.shutdown(); + + assertEquals(threadCount, successCount.get()); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(threadCount, finalSession.events().size()); + } + + @Test + public void testNoLostUpdates() throws InterruptedException { + String sessionId = "no-lost-updates"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int i = 0; i < threadCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + eventNum) + .author("thread-" + eventNum) + .content(Content.fromParts(Part.fromText("Message " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS)); + executor.shutdown(); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(threadCount, finalSession.events().size()); + } + + @Test + public void testConcurrentAppendDifferentSessions() throws InterruptedException { + int sessionCount = 5; + int eventsPerSession = 10; + + for (int i = 0; i < sessionCount; i++) { + String sessionId = "session-" + i; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + } + + ExecutorService executor = Executors.newFixedThreadPool(sessionCount * eventsPerSession); + CountDownLatch latch = new CountDownLatch(sessionCount * eventsPerSession); + + for (int i = 0; i < sessionCount; i++) { + final String sessionId = "session-" + i; + for (int j = 0; j < eventsPerSession; j++) { + final int eventNum = j; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Event " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + } + + assertTrue(latch.await(60, TimeUnit.SECONDS)); + executor.shutdown(); + + for (int i = 0; i < sessionCount; i++) { + String sessionId = "session-" + i; + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(session); + assertEquals(eventsPerSession, session.events().size()); + } + } + + @Test + public void testAppendEventUnderLoad() throws InterruptedException { + String sessionId = "load-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int threadCount = 50; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + AtomicInteger successCount = new AtomicInteger(0); + AtomicInteger failureCount = new AtomicInteger(0); + + for (int i = 0; i < threadCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + eventNum) + .author("thread-" + eventNum) + .content(Content.fromParts(Part.fromText("Load test " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + successCount.incrementAndGet(); + } catch (Exception e) { + failureCount.incrementAndGet(); + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(120, TimeUnit.SECONDS)); + executor.shutdown(); + + assertEquals(threadCount, successCount.get()); + assertEquals(0, failureCount.get()); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(threadCount, finalSession.events().size()); + } + + @Test + public void testEventOrderingConsistency() throws InterruptedException { + String sessionId = "ordering-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int eventCount = 100; + ExecutorService executor = Executors.newFixedThreadPool(10); + CountDownLatch latch = new CountDownLatch(eventCount); + + for (int i = 0; i < eventCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + String.format("%03d", eventNum)) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + eventNum))) + .timestamp(Instant.now().toEpochMilli() + eventNum) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(120, TimeUnit.SECONDS)); + executor.shutdown(); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(eventCount, finalSession.events().size()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLAgentIntegrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLAgentIntegrationTest.java new file mode 100644 index 000000000..7db4e3745 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLAgentIntegrationTest.java @@ -0,0 +1,364 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.SequentialAgent; +import com.google.adk.events.Event; +import com.google.adk.models.LlmResponse; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.adk.testing.TestLlm; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +/** + * Integration tests for Agents using DatabaseSessionService with real PostgreSQL 16 database. + * + *

This test suite verifies that agents work correctly with PostgreSQL-backed session + * persistence, including: - Sequential agent execution with database persistence - State + * propagation between agents via outputKey - Event storage and retrieval - App/user/session state + * management with database backend - JSONB storage for complex state data + * + *

Prerequisites: Start PostgreSQL test database with: + * + *

{@code
+ * docker-compose -f scripts/docker-compose.test.yml up -d postgres-test
+ * }
+ * + *

Configuration: - Host: localhost:5433 - Database: adk_test - User: adk_user - Password: + * adk_password + */ +@Tag("integration") +public class PostgreSQLAgentIntegrationTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.POSTGRES_JDBC_URL; + private static final String TEST_APP_NAME = "postgres-agent-integration-test"; + private static final String TEST_USER_ID = "agent-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isPostgreSQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("PostgreSQL")); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testSequentialAgentWithDatabasePersistence() { + Content agentAResponse = Content.fromParts(Part.fromText("The topic is: Machine Learning")); + TestLlm llmA = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(agentAResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentA = + LlmAgent.builder() + .name("AgentA") + .model(llmA) + .instruction("Extract topic") + .outputKey("topic") + .build(); + + Content agentBResponse = + Content.fromParts(Part.fromText("Summary: Machine Learning is transformative")); + TestLlm llmB = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(agentBResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentB = + LlmAgent.builder() + .name("AgentB") + .model(llmB) + .instruction("Summarize topic: ${topic}") + .outputKey("summary") + .build(); + + SequentialAgent sequential = + SequentialAgent.builder() + .name("SequentialAgent") + .subAgents(ImmutableList.of(agentA, agentB)) + .build(); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx = + InvocationContext.builder() + .sessionService(sessionService) + .session(session) + .agent(sequential) + .build(); + + List events = + sequential + .runAsync(ctx) + .flatMap(event -> ctx.sessionService().appendEvent(ctx.session(), event).toFlowable()) + .toList() + .blockingGet(); + + assertNotNull(events); + assertTrue(events.size() >= 2, "Expected at least 2 events from sequential agents"); + + assertEquals("The topic is: Machine Learning", ctx.session().state().get("topic")); + assertEquals( + "Summary: Machine Learning is transformative", ctx.session().state().get("summary")); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, session.id(), Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals(session.id(), retrievedSession.id()); + assertEquals("The topic is: Machine Learning", retrievedSession.state().get("topic")); + assertEquals( + "Summary: Machine Learning is transformative", retrievedSession.state().get("summary")); + assertTrue( + retrievedSession.events().size() >= 2, "Expected at least 2 events persisted in database"); + } + + @Test + public void testAgentWithAppAndUserStatePersistence() { + Content configResponse = + Content.fromParts( + Part.fromText("{\"version\": \"2.0\", \"feature_flags\": {\"new_ui\": true}}")); + TestLlm llmA = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(configResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentA = + LlmAgent.builder() + .name("ConfigAgent") + .model(llmA) + .instruction("Return config") + .outputKey("app:config") + .build(); + + Content prefResponse = Content.fromParts(Part.fromText("light")); + TestLlm llmB = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(prefResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentB = + LlmAgent.builder() + .name("PreferenceAgent") + .model(llmB) + .instruction("Use config: ${app:config}") + .outputKey("user:theme") + .build(); + + SequentialAgent sequential = + SequentialAgent.builder() + .name("Sequential") + .subAgents(ImmutableList.of(agentA, agentB)) + .build(); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx = + InvocationContext.builder() + .sessionService(sessionService) + .session(session) + .agent(sequential) + .build(); + + sequential + .runAsync(ctx) + .flatMap(event -> ctx.sessionService().appendEvent(ctx.session(), event).toFlowable()) + .toList() + .blockingGet(); + + assertEquals( + "{\"version\": \"2.0\", \"feature_flags\": {\"new_ui\": true}}", + ctx.session().state().get("app:config")); + assertEquals("light", ctx.session().state().get("user:theme")); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, session.id(), Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals( + "{\"version\": \"2.0\", \"feature_flags\": {\"new_ui\": true}}", + retrievedSession.state().get("app:config")); + assertEquals("light", retrievedSession.state().get("user:theme")); + } + + @Test + public void testAgentStatePersistedAcrossSessions() { + Content response1 = Content.fromParts(Part.fromText("User preference stored")); + TestLlm llm1 = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(response1) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agent1 = + LlmAgent.builder() + .name("PreferenceAgent") + .model(llm1) + .instruction("Store preference") + .outputKey("user:language") + .build(); + + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx1 = + InvocationContext.builder() + .sessionService(sessionService) + .session(session1) + .agent(agent1) + .build(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("user:language", "French"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author(agent1.name()) + .content(response1) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertEquals("French", session2.state().get("user:language")); + } + + @Test + public void testRegularStateIsolatedBetweenSessions() { + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("session_data", "session1_value"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author("test-agent") + .content(Content.fromParts(Part.fromText("Session 1 data"))) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + assertEquals("session1_value", session1.state().get("session_data")); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertNull( + session2.state().get("session_data"), "Regular state should not persist across sessions"); + } + + @Test + public void testAppStatePersistedAcrossSessions() { + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("app:api_key", "key-67890"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author("test-agent") + .content(Content.fromParts(Part.fromText("App config stored"))) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertEquals("key-67890", session2.state().get("app:api_key")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLFunctionalTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLFunctionalTest.java new file mode 100644 index 000000000..dc579d250 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLFunctionalTest.java @@ -0,0 +1,184 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +@Tag("integration") +public class PostgreSQLFunctionalTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.POSTGRES_JDBC_URL; + private String TEST_APP_NAME; + private String TEST_USER_ID; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isPostgreSQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("PostgreSQL")); + + TEST_APP_NAME = "jdbc-postgres-test-app-" + System.currentTimeMillis(); + TEST_USER_ID = "jdbc-postgres-test-user-" + System.currentTimeMillis(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testBasicSessionOperations() { + String sessionId = "postgres-basic-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key", "value"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals(sessionId, session.id()); + assertEquals("value", session.state().get("key")); + } + + @Test + public void testEventActionsWithStateDelta() { + String sessionId = "postgres-actions-test-" + System.currentTimeMillis(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("count", 1); + stateDelta.put("app:shared", "global"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals(1, retrieved.state().get("count")); + assertEquals("global", retrieved.state().get("app:shared")); + } + + @Test + public void testJSONBStorageAndRetrieval() { + String sessionId = "postgres-jsonb-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("nested", java.util.Map.of("inner", "value")); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertNotNull(retrieved.state().get("nested")); + } + + @Test + public void testUpsertAppState() { + String sessionId1 = "postgres-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "postgres-upsert-2-" + System.currentTimeMillis(); + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("app:config", "value1"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("app:config", "value2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId1, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("value2", retrieved.state().get("app:config")); + } + + @Test + public void testGetSessionWithInvalidConfig() { + String sessionId = "postgres-invalid-config-test-" + System.currentTimeMillis(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Add 5 events + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + // Test negative numRecentEvents: -1 should be treated as abs(-1) = 1 (last 1 event) + GetSessionConfig negativeNumEvents = GetSessionConfig.builder().numRecentEvents(-1).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(negativeNumEvents)) + .blockingGet(); + + assertNotNull(session); + // Should return exactly 1 event (the most recent one) + assertEquals( + 1, + session.events().size(), + "Expected 1 event for numRecentEvents=-1, got " + session.events().size()); + // Should be the last event added (event-5) + assertEquals( + "event-5", + session.events().get(0).id(), + "Expected most recent event (event-5), got " + session.events().get(0).id()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLIntegrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLIntegrationTest.java new file mode 100644 index 000000000..2b624f78e --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLIntegrationTest.java @@ -0,0 +1,807 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +/** + * Integration tests that verify PostgreSQL database table operations directly. + * + *

This test suite validates: - sessions table: CRUD operations, timestamps, state storage - + * events table: event persistence, foreign key relationships, cascading deletes - app_states table: + * application-wide state management - user_states table: user-specific state management - JSONB + * storage and retrieval - Timestamp tracking (create_time, update_time) + * + *

Prerequisites: Start PostgreSQL test database with: + * + *

{@code
+ * docker run -d -p 5432:5432 \
+ *   -e POSTGRES_DB=adk_test \
+ *   -e POSTGRES_USER=adk_user \
+ *   -e POSTGRES_PASSWORD=adk_password \
+ *   postgres:15
+ * }
+ */ +@Tag("integration") +public class PostgreSQLIntegrationTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.POSTGRES_JDBC_URL; + private static final String TEST_APP_NAME = "table-test-app"; + private static final String TEST_USER_ID = "table-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isPostgreSQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("PostgreSQL")); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + // ==================== SESSIONS TABLE TESTS ==================== + + @Test + public void testSessionsTableCreation() throws SQLException { + String sessionId = "session-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("test_key", "test_value"); + + // Create session via service + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify in database directly + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT app_name, user_id, id, state, create_time, update_time " + + "FROM sessions WHERE app_name = ? AND user_id = ? AND id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + stmt.setString(3, sessionId); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Session should exist in sessions table"); + + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + assertEquals(TEST_USER_ID, rs.getString("user_id")); + assertEquals(sessionId, rs.getString("id")); + + // Verify JSONB state + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertTrue(stateJson.contains("test_key")); + assertTrue(stateJson.contains("test_value")); + + // Verify timestamps + Timestamp createTime = rs.getTimestamp("create_time"); + Timestamp updateTime = rs.getTimestamp("update_time"); + assertNotNull(createTime, "create_time should not be null"); + assertNotNull(updateTime, "update_time should not be null"); + + assertFalse(rs.next(), "Should only have one session with this ID"); + } + } + } + } + + @Test + public void testSessionsTableUpdate() throws SQLException { + String sessionId = "session-update-" + System.currentTimeMillis(); + + // Create initial session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Timestamp initialUpdateTime; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT update_time FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + initialUpdateTime = rs.getTimestamp("update_time"); + } + } + } + + // Wait a bit to ensure timestamp difference + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Update session by appending event + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("updated_key", "updated_value"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Update event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify update_time changed + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT state, update_time FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + + String stateJson = rs.getString("state"); + assertTrue(stateJson.contains("updated_key")); + assertTrue(stateJson.contains("updated_value")); + + Timestamp newUpdateTime = rs.getTimestamp("update_time"); + assertTrue( + newUpdateTime.after(initialUpdateTime), + "update_time should be updated after state change"); + } + } + } + } + + @Test + public void testSessionsTablePrimaryKey() throws SQLException { + String sessionId = "session-pk-" + System.currentTimeMillis(); + + // Create session + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Verify primary key constraint (app_name, user_id, id) + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + // Count sessions with this ID + String query = "SELECT COUNT(*) FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Should have exactly one session with this ID"); + } + } + } + } + + // ==================== EVENTS TABLE TESTS ==================== + + @Test + public void testEventsTableCreation() throws SQLException { + String sessionId = "session-events-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Create event + String eventId = UUID.randomUUID().toString(); + String invocationId = "inv-" + UUID.randomUUID().toString(); + long timestamp = Instant.now().toEpochMilli(); + + Event event = + Event.builder() + .id(eventId) + .invocationId(invocationId) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event content"))) + .timestamp(timestamp) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify in database + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT id, app_name, user_id, session_id, invocation_id, timestamp, event_data " + + "FROM events WHERE id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, eventId); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Event should exist in events table"); + + assertEquals(eventId, rs.getString("id")); + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + assertEquals(TEST_USER_ID, rs.getString("user_id")); + assertEquals(sessionId, rs.getString("session_id")); + assertEquals(invocationId, rs.getString("invocation_id")); + + Timestamp eventTimestamp = rs.getTimestamp("timestamp"); + assertNotNull(eventTimestamp); + + // Verify JSONB event_data + String eventData = rs.getString("event_data"); + assertNotNull(eventData); + assertTrue(eventData.contains("Test event content")); + assertTrue(eventData.contains("test-author")); + + assertFalse(rs.next(), "Should only have one event with this ID"); + } + } + } + } + + @Test + public void testEventsTableMultipleEvents() throws SQLException { + String sessionId = "session-multi-events-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Create multiple events + int numEvents = 5; + for (int i = 0; i < numEvents; i++) { + Event event = + Event.builder() + .id("event-" + i + "-" + UUID.randomUUID()) + .author("author-" + i) + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + // Verify event count in database + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT COUNT(*) FROM events WHERE app_name = ? AND user_id = ? AND session_id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + stmt.setString(3, sessionId); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(numEvents, rs.getInt(1), "Should have " + numEvents + " events"); + } + } + } + } + + @Test + public void testEventsTableForeignKeyConstraint() throws SQLException { + String sessionId = "session-fk-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Add event + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("FK test event"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify events exist + int eventCountBefore; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM events WHERE session_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + eventCountBefore = rs.getInt(1); + assertTrue(eventCountBefore > 0, "Should have at least one event"); + } + } + } + + // Delete session (should cascade delete events) + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String deleteQuery = "DELETE FROM sessions WHERE app_name = ? AND user_id = ? AND id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(deleteQuery)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + stmt.setString(3, sessionId); + stmt.executeUpdate(); + } + } + + // Verify events were cascade deleted + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM events WHERE session_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(0, rs.getInt(1), "Events should be cascade deleted with session"); + } + } + } + } + + @Test + public void testEventsTableTimestampOrdering() throws SQLException { + String sessionId = "session-timestamp-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Create events with specific timestamps + long baseTime = Instant.now().toEpochMilli(); + for (int i = 0; i < 3; i++) { + Event event = + Event.builder() + .id("event-" + i + "-" + UUID.randomUUID()) + .author("author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(baseTime + (i * 1000)) + .build(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + // Verify events are ordered by timestamp + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT event_data->>'id' as event_id, timestamp " + + "FROM events WHERE session_id = ? ORDER BY timestamp ASC"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + + try (ResultSet rs = stmt.executeQuery()) { + Timestamp previousTimestamp = null; + int count = 0; + + while (rs.next()) { + Timestamp currentTimestamp = rs.getTimestamp("timestamp"); + assertNotNull(currentTimestamp); + + if (previousTimestamp != null) { + assertTrue( + currentTimestamp.after(previousTimestamp) + || currentTimestamp.equals(previousTimestamp), + "Events should be ordered by timestamp"); + } + + previousTimestamp = currentTimestamp; + count++; + } + + assertEquals(3, count, "Should have 3 events"); + } + } + } + } + + // ==================== APP_STATES TABLE TESTS ==================== + + @Test + public void testAppStatesTableCreation() throws SQLException { + String sessionId = "session-app-state-" + System.currentTimeMillis(); + + // Create session with app state + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("app:api_endpoint", "https://api.example.com"); + state.put("app:version", "2.0"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify in app_states table + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT app_name, state, update_time FROM app_states WHERE app_name = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "App state should exist in app_states table"); + + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertTrue(stateJson.contains("api_endpoint")); + assertTrue(stateJson.contains("https://api.example.com")); + assertTrue(stateJson.contains("version")); + assertTrue(stateJson.contains("2.0")); + + Timestamp updateTime = rs.getTimestamp("update_time"); + assertNotNull(updateTime, "update_time should not be null"); + } + } + } + } + + @Test + public void testAppStatesTableUpsert() throws SQLException { + String sessionId1 = "session-app-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "session-app-upsert-2-" + System.currentTimeMillis(); + + // Create first session with app state + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("app:config", "version1"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + Timestamp firstUpdateTime; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT update_time FROM app_states WHERE app_name = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + firstUpdateTime = rs.getTimestamp("update_time"); + } + } + } + + // Wait to ensure timestamp difference + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Create second session with updated app state + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("app:config", "version2"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + // Verify app state was updated (not duplicated) + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT COUNT(*), state, update_time FROM app_states WHERE app_name = ? GROUP BY state, update_time"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Should have app state"); + assertEquals(1, rs.getInt(1), "Should have only one row for this app"); + + String stateJson = rs.getString("state"); + assertTrue(stateJson.contains("version2"), "State should be updated to version2"); + + Timestamp newUpdateTime = rs.getTimestamp("update_time"); + assertTrue( + newUpdateTime.after(firstUpdateTime) || newUpdateTime.equals(firstUpdateTime), + "update_time should be updated"); + + assertFalse(rs.next(), "Should only have one app state row"); + } + } + } + } + + @Test + public void testAppStatesTablePrimaryKey() throws SQLException { + String sessionId = "session-app-pk-" + System.currentTimeMillis(); + + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("app:key", "value"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify only one row per app_name + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM app_states WHERE app_name = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Should have exactly one row per app_name"); + } + } + } + } + + // ==================== USER_STATES TABLE TESTS ==================== + + @Test + public void testUserStatesTableCreation() throws SQLException { + String sessionId = "session-user-state-" + System.currentTimeMillis(); + + // Create session with user state + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("user:language", "English"); + state.put("user:timezone", "UTC"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify in user_states table + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT app_name, user_id, state, update_time FROM user_states WHERE app_name = ? AND user_id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "User state should exist in user_states table"); + + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + assertEquals(TEST_USER_ID, rs.getString("user_id")); + + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertTrue(stateJson.contains("language")); + assertTrue(stateJson.contains("English")); + assertTrue(stateJson.contains("timezone")); + assertTrue(stateJson.contains("UTC")); + + Timestamp updateTime = rs.getTimestamp("update_time"); + assertNotNull(updateTime, "update_time should not be null"); + } + } + } + } + + @Test + public void testUserStatesTableUpsert() throws SQLException { + String sessionId1 = "session-user-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "session-user-upsert-2-" + System.currentTimeMillis(); + + // Create first session with user state + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("user:preference", "dark"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + Timestamp firstUpdateTime; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT update_time FROM user_states WHERE app_name = ? AND user_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + firstUpdateTime = rs.getTimestamp("update_time"); + } + } + } + + // Wait to ensure timestamp difference + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Create second session with updated user state + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("user:preference", "light"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + // Verify user state was updated (not duplicated) + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT COUNT(*), state, update_time FROM user_states " + + "WHERE app_name = ? AND user_id = ? GROUP BY state, update_time"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Should have user state"); + assertEquals(1, rs.getInt(1), "Should have only one row for this user"); + + String stateJson = rs.getString("state"); + assertTrue(stateJson.contains("light"), "State should be updated to 'light'"); + + Timestamp newUpdateTime = rs.getTimestamp("update_time"); + assertTrue( + newUpdateTime.after(firstUpdateTime) || newUpdateTime.equals(firstUpdateTime), + "update_time should be updated"); + + assertFalse(rs.next(), "Should only have one user state row"); + } + } + } + } + + @Test + public void testUserStatesTablePrimaryKey() throws SQLException { + String sessionId = "session-user-pk-" + System.currentTimeMillis(); + + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("user:key", "value"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify only one row per (app_name, user_id) combination + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM user_states WHERE app_name = ? AND user_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Should have exactly one row per (app_name, user_id)"); + } + } + } + } + + @Test + public void testUserStatesTableIsolationBetweenUsers() throws SQLException { + String user1 = "user-1-" + System.currentTimeMillis(); + String user2 = "user-2-" + System.currentTimeMillis(); + + // Create sessions for two different users + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("user:language", "French"); + sessionService.createSession(TEST_APP_NAME, user1, state1, "session-1").blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("user:language", "Spanish"); + sessionService.createSession(TEST_APP_NAME, user2, state2, "session-2").blockingGet(); + + // Verify both users have separate state + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT user_id, state FROM user_states WHERE app_name = ? AND user_id IN (?, ?)"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, user1); + stmt.setString(3, user2); + try (ResultSet rs = stmt.executeQuery()) { + int count = 0; + while (rs.next()) { + String userId = rs.getString("user_id"); + String stateJson = rs.getString("state"); + + if (userId.equals(user1)) { + assertTrue(stateJson.contains("French"), "User 1 should have French"); + } else if (userId.equals(user2)) { + assertTrue(stateJson.contains("Spanish"), "User 2 should have Spanish"); + } + count++; + } + assertEquals(2, count, "Should have state for both users"); + } + } + } + } + + // ==================== CROSS-TABLE INTEGRATION TESTS ==================== + + @Test + public void testAllTablesIntegration() throws SQLException { + String sessionId = "session-integration-" + System.currentTimeMillis(); + + // Create session with all state types + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("session_key", "session_value"); // session state + state.put("app:api_key", "app-12345"); // app state + state.put("user:theme", "dark"); // user state + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Add event + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("integration-test") + .content(Content.fromParts(Part.fromText("Integration test event"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify all tables have data + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + // Check sessions table + String sessionQuery = "SELECT COUNT(*) FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(sessionQuery)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Session should exist"); + } + } + + // Check events table + String eventQuery = "SELECT COUNT(*) FROM events WHERE session_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(eventQuery)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertTrue(rs.getInt(1) > 0, "Should have at least one event"); + } + } + + // Check app_states table + String appQuery = "SELECT state FROM app_states WHERE app_name = ?"; + try (PreparedStatement stmt = conn.prepareStatement(appQuery)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "App state should exist"); + String appState = rs.getString("state"); + assertTrue(appState.contains("api_key"), "App state should contain api_key"); + } + } + + // Check user_states table + String userQuery = "SELECT state FROM user_states WHERE app_name = ? AND user_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(userQuery)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "User state should exist"); + String userState = rs.getString("state"); + assertTrue(userState.contains("theme"), "User state should contain theme"); + } + } + } + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/ReadTwiceNonDestructiveTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ReadTwiceNonDestructiveTest.java new file mode 100644 index 000000000..224c6c2b9 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ReadTwiceNonDestructiveTest.java @@ -0,0 +1,267 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ReadTwiceNonDestructiveTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:read_twice_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "read-twice-test-app"; + private static final String TEST_USER_ID = "read-twice-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testGetSessionTwiceDoesNotDeleteEvents() { + String sessionId = "non-destructive-read-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + Session firstRead = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(firstRead); + assertEquals(5, firstRead.events().size(), "First read should return 5 events"); + + Session secondRead = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(secondRead); + assertEquals(5, secondRead.events().size(), "Second read should still return 5 events"); + + Session thirdRead = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(thirdRead); + assertEquals(5, thirdRead.events().size(), "Third read should still return 5 events"); + } + + @Test + public void testGetSessionMultipleTimesWithFiltering() { + String sessionId = "filter-read-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant startTime = Instant.now(); + + for (int i = 1; i <= 10; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(startTime.plusSeconds(i).toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + Session allEvents = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals(10, allEvents.events().size()); + + GetSessionConfig recentConfig = GetSessionConfig.builder().numRecentEvents(3).build(); + Session recentEvents = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(recentConfig)) + .blockingGet(); + assertEquals(3, recentEvents.events().size()); + + Session allEventsAgain = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals(10, allEventsAgain.events().size(), "All events should still exist in DB"); + + GetSessionConfig timestampConfig = + GetSessionConfig.builder().afterTimestamp(startTime.plusSeconds(5)).build(); + Session filteredByTime = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(timestampConfig)) + .blockingGet(); + assertEquals(5, filteredByTime.events().size()); + + Session finalRead = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals(10, finalRead.events().size(), "Events should persist after all filtered reads"); + } + + @Test + public void testConcurrentReadsDoNotAffectData() throws InterruptedException { + String sessionId = "concurrent-read-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + Thread[] readers = new Thread[10]; + for (int i = 0; i < 10; i++) { + readers[i] = + new Thread( + () -> { + for (int j = 0; j < 5; j++) { + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals( + 5, session.events().size(), "Each concurrent read should return 5 events"); + } + }); + readers[i].start(); + } + + for (Thread reader : readers) { + reader.join(); + } + + Session finalCheck = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals( + 5, finalCheck.events().size(), "Events should remain intact after concurrent reads"); + } + + @Test + public void testReadWithDifferentConfigs() { + String sessionId = "config-variation-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant startTime = Instant.now(); + + for (int i = 1; i <= 10; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(startTime.plusSeconds(i).toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + for (int recentCount = 1; recentCount <= 10; recentCount++) { + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(recentCount).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + assertEquals( + recentCount, session.events().size(), "Should get " + recentCount + " recent events"); + } + + Session fullRead = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals( + 10, + fullRead.events().size(), + "All 10 events should still exist after multiple config reads"); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/SessionUpdateTimeTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/SessionUpdateTimeTest.java new file mode 100644 index 000000000..6c9fb9ff0 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/SessionUpdateTimeTest.java @@ -0,0 +1,361 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.*; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import java.time.Instant; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests that verify the session's update_time is correctly updated in all scenarios when appending + * events. + * + *

These tests address the bug where session update_time was not being updated when events + * contained only app: or user: prefixed state deltas. + */ +class SessionUpdateTimeTest { + + private static final String APP_NAME = "testApp"; + private static final String USER_ID = "testUser"; + + private DatabaseSessionService sessionService; + + @BeforeEach + void setUp() { + String jdbcUrl = + "jdbc:h2:mem:session_update_time_test_" + + UUID.randomUUID() + + ";DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + sessionService = new DatabaseSessionService(jdbcUrl); + } + + @AfterEach + void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + void testSessionUpdateTime_AppendEventWithNoStateDelta() throws InterruptedException { + Session session = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + Instant initialUpdateTime = session.lastUpdateTime(); + assertNotNull(initialUpdateTime, "Initial update time should not be null"); + + Thread.sleep(100); + + Event eventWithNoState = + Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-1") + .timestamp(System.currentTimeMillis()) + .build(); + + sessionService.appendEvent(session, eventWithNoState).blockingGet(); + + Session updatedSession = + sessionService + .getSession(APP_NAME, USER_ID, session.id(), java.util.Optional.empty()) + .blockingGet(); + + assertTrue( + updatedSession.lastUpdateTime().isAfter(initialUpdateTime), + "Session update_time should be updated even with no state delta"); + } + + @Test + void testSessionUpdateTime_AppendEventWithOnlyAppStateDelta() throws InterruptedException { + Session session = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + Instant initialUpdateTime = session.lastUpdateTime(); + assertNotNull(initialUpdateTime, "Initial update time should not be null"); + + Thread.sleep(100); + + ConcurrentMap appOnlyDelta = new ConcurrentHashMap<>(); + appOnlyDelta.put("app:setting1", "value1"); + appOnlyDelta.put("app:setting2", 42); + + Event eventWithAppState = + Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-2") + .timestamp(System.currentTimeMillis()) + .actions(EventActions.builder().stateDelta(appOnlyDelta).build()) + .build(); + + sessionService.appendEvent(session, eventWithAppState).blockingGet(); + + Session updatedSession = + sessionService + .getSession(APP_NAME, USER_ID, session.id(), java.util.Optional.empty()) + .blockingGet(); + + assertTrue( + updatedSession.lastUpdateTime().isAfter(initialUpdateTime), + "Session update_time should be updated when event has only app: prefixed state delta"); + } + + @Test + void testSessionUpdateTime_AppendEventWithOnlyUserStateDelta() throws InterruptedException { + Session session = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + Instant initialUpdateTime = session.lastUpdateTime(); + assertNotNull(initialUpdateTime, "Initial update time should not be null"); + + Thread.sleep(100); + + ConcurrentMap userOnlyDelta = new ConcurrentHashMap<>(); + userOnlyDelta.put("user:preference1", "dark"); + userOnlyDelta.put("user:preference2", true); + + Event eventWithUserState = + Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-3") + .timestamp(System.currentTimeMillis()) + .actions(EventActions.builder().stateDelta(userOnlyDelta).build()) + .build(); + + sessionService.appendEvent(session, eventWithUserState).blockingGet(); + + Session updatedSession = + sessionService + .getSession(APP_NAME, USER_ID, session.id(), java.util.Optional.empty()) + .blockingGet(); + + assertTrue( + updatedSession.lastUpdateTime().isAfter(initialUpdateTime), + "Session update_time should be updated when event has only user: prefixed state delta"); + } + + @Test + void testSessionUpdateTime_AppendEventWithSessionStateDelta() throws InterruptedException { + Session session = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + Instant initialUpdateTime = session.lastUpdateTime(); + assertNotNull(initialUpdateTime, "Initial update time should not be null"); + + Thread.sleep(100); + + ConcurrentMap sessionDelta = new ConcurrentHashMap<>(); + sessionDelta.put("counter", 1); + sessionDelta.put("status", "active"); + + Event eventWithSessionState = + Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-4") + .timestamp(System.currentTimeMillis()) + .actions(EventActions.builder().stateDelta(sessionDelta).build()) + .build(); + + sessionService.appendEvent(session, eventWithSessionState).blockingGet(); + + Session updatedSession = + sessionService + .getSession(APP_NAME, USER_ID, session.id(), java.util.Optional.empty()) + .blockingGet(); + + assertTrue( + updatedSession.lastUpdateTime().isAfter(initialUpdateTime), + "Session update_time should be updated when event has session state delta"); + } + + @Test + void testSessionUpdateTime_AppendEventWithMixedStateDelta() throws InterruptedException { + Session session = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + Instant initialUpdateTime = session.lastUpdateTime(); + assertNotNull(initialUpdateTime, "Initial update time should not be null"); + + Thread.sleep(100); + + ConcurrentMap mixedDelta = new ConcurrentHashMap<>(); + mixedDelta.put("app:version", "2.0"); + mixedDelta.put("user:theme", "light"); + mixedDelta.put("currentStep", 5); + + Event eventWithMixedState = + Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-5") + .timestamp(System.currentTimeMillis()) + .actions(EventActions.builder().stateDelta(mixedDelta).build()) + .build(); + + sessionService.appendEvent(session, eventWithMixedState).blockingGet(); + + Session updatedSession = + sessionService + .getSession(APP_NAME, USER_ID, session.id(), java.util.Optional.empty()) + .blockingGet(); + + assertTrue( + updatedSession.lastUpdateTime().isAfter(initialUpdateTime), + "Session update_time should be updated when event has mixed state delta"); + } + + @Test + void testSessionUpdateTime_AppendEventWithOnlyTempStateDelta() throws InterruptedException { + Session session = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + Instant initialUpdateTime = session.lastUpdateTime(); + assertNotNull(initialUpdateTime, "Initial update time should not be null"); + + Thread.sleep(100); + + ConcurrentMap tempOnlyDelta = new ConcurrentHashMap<>(); + tempOnlyDelta.put("temp:cache", "somevalue"); + tempOnlyDelta.put("temp:ui_state", Map.of("x", 10)); + + Event eventWithTempState = + Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-6") + .timestamp(System.currentTimeMillis()) + .actions(EventActions.builder().stateDelta(tempOnlyDelta).build()) + .build(); + + sessionService.appendEvent(session, eventWithTempState).blockingGet(); + + Session updatedSession = + sessionService + .getSession(APP_NAME, USER_ID, session.id(), java.util.Optional.empty()) + .blockingGet(); + + assertTrue( + updatedSession.lastUpdateTime().isAfter(initialUpdateTime), + "Session update_time should be updated even when event has only temp: prefixed state delta"); + } + + @Test + void testSessionUpdateTime_MultipleEventsInSequence() throws InterruptedException { + Session session = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + Instant time1 = session.lastUpdateTime(); + Thread.sleep(100); + + Event event1 = + Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-seq-1") + .timestamp(System.currentTimeMillis()) + .actions( + EventActions.builder() + .stateDelta(new ConcurrentHashMap<>(Map.of("app:config", "v1"))) + .build()) + .build(); + + sessionService.appendEvent(session, event1).blockingGet(); + Session session2 = + sessionService + .getSession(APP_NAME, USER_ID, session.id(), java.util.Optional.empty()) + .blockingGet(); + Instant time2 = session2.lastUpdateTime(); + assertTrue(time2.isAfter(time1), "Update time should increase after first event"); + + Thread.sleep(100); + + Event event2 = + Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-seq-2") + .timestamp(System.currentTimeMillis()) + .actions( + EventActions.builder() + .stateDelta(new ConcurrentHashMap<>(Map.of("user:pref", "v2"))) + .build()) + .build(); + + sessionService.appendEvent(session2, event2).blockingGet(); + Session session3 = + sessionService + .getSession(APP_NAME, USER_ID, session.id(), java.util.Optional.empty()) + .blockingGet(); + Instant time3 = session3.lastUpdateTime(); + assertTrue(time3.isAfter(time2), "Update time should increase after second event"); + + Thread.sleep(100); + + Event event3 = + Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-seq-3") + .timestamp(System.currentTimeMillis()) + .build(); + + sessionService.appendEvent(session3, event3).blockingGet(); + Session session4 = + sessionService + .getSession(APP_NAME, USER_ID, session.id(), java.util.Optional.empty()) + .blockingGet(); + Instant time4 = session4.lastUpdateTime(); + assertTrue(time4.isAfter(time3), "Update time should increase after third event with no state"); + } + + @Test + void testSessionUpdateTime_AppendEventWithAppAndUserStateDeltaOnly() throws InterruptedException { + Session session = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + Instant initialUpdateTime = session.lastUpdateTime(); + assertNotNull(initialUpdateTime, "Initial update time should not be null"); + + Thread.sleep(100); + + ConcurrentMap appAndUserDelta = new ConcurrentHashMap<>(); + appAndUserDelta.put("app:globalSetting", "enabled"); + appAndUserDelta.put("user:personalSetting", "custom"); + + Event eventWithAppAndUserState = + Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-7") + .timestamp(System.currentTimeMillis()) + .actions(EventActions.builder().stateDelta(appAndUserDelta).build()) + .build(); + + sessionService.appendEvent(session, eventWithAppAndUserState).blockingGet(); + + Session updatedSession = + sessionService + .getSession(APP_NAME, USER_ID, session.id(), java.util.Optional.empty()) + .blockingGet(); + + assertTrue( + updatedSession.lastUpdateTime().isAfter(initialUpdateTime), + "Session update_time should be updated when event has app and user state delta but no session state delta"); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/SpannerFunctionalTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/SpannerFunctionalTest.java new file mode 100644 index 000000000..df6893668 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/SpannerFunctionalTest.java @@ -0,0 +1,184 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +@Tag("integration") +public class SpannerFunctionalTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.SPANNER_JDBC_URL; + private String TEST_APP_NAME; + private String TEST_USER_ID; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isSpannerAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("Spanner")); + + TEST_APP_NAME = "jdbc-spanner-test-app-" + System.currentTimeMillis(); + TEST_USER_ID = "jdbc-spanner-test-user-" + System.currentTimeMillis(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testBasicSessionOperations() { + String sessionId = "spanner-basic-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key", "value"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals(sessionId, session.id()); + assertEquals("value", session.state().get("key")); + } + + @Test + public void testEventActionsWithStateDelta() { + String sessionId = "spanner-actions-test-" + System.currentTimeMillis(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("count", 1); + stateDelta.put("_app_shared", "global"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals(1, retrieved.state().get("count")); + assertEquals("global", retrieved.state().get("_app_shared")); + } + + @Test + public void testJSONStorageAndRetrieval() { + String sessionId = "spanner-json-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("nested", java.util.Map.of("inner", "value")); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertNotNull(retrieved.state().get("nested")); + } + + @Test + public void testUpsertAppState() { + String sessionId1 = "spanner-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "spanner-upsert-2-" + System.currentTimeMillis(); + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("app:config", "value1"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("app:config", "value2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId1, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("value2", retrieved.state().get("app:config")); + } + + @Test + public void testGetSessionWithInvalidConfig() { + String sessionId = "spanner-invalid-config-test-" + System.currentTimeMillis(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Add 5 events + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + // Test negative numRecentEvents: -1 should be treated as abs(-1) = 1 (last 1 event) + GetSessionConfig negativeNumEvents = GetSessionConfig.builder().numRecentEvents(-1).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(negativeNumEvents)) + .blockingGet(); + + assertNotNull(session); + // Should return exactly 1 event (the most recent one) + assertEquals( + 1, + session.events().size(), + "Expected 1 event for numRecentEvents=-1, got " + session.events().size()); + // Should be the last event added (event-5) + assertEquals( + "event-5", + session.events().get(0).id(), + "Expected most recent event (event-5), got " + session.events().get(0).id()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/SpannerIntegrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/SpannerIntegrationTest.java new file mode 100644 index 000000000..48250f1ba --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/SpannerIntegrationTest.java @@ -0,0 +1,808 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +/** + * Integration tests that verify Spanner database table operations directly. + * + *

This test suite validates: - sessions table: CRUD operations, timestamps, state storage - + * events table: event persistence, foreign key relationships, cascading deletes - app_states table: + * application-wide state management - user_states table: user-specific state management - STRING + * (JSON) storage and retrieval - Timestamp tracking (create_time, update_time) + * + *

Prerequisites: Ensure Spanner instance and database are configured. Set environment variables: + * SPANNER_PROJECT_ID, SPANNER_INSTANCE_ID, SPANNER_DATABASE_ID + */ +@Tag("integration") +public class SpannerIntegrationTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.SPANNER_JDBC_URL; + private static final String TEST_APP_NAME = "spanner-table-test-app"; + private static final String TEST_USER_ID = "spanner-table-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isSpannerAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("Spanner")); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + // ==================== SESSIONS TABLE TESTS ==================== + + @Test + public void testSessionsTableCreation() throws SQLException { + String sessionId = "session-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("test_key", "test_value"); + + // Create session via service + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify in database directly + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT app_name, user_id, id, state, create_time, update_time " + + "FROM sessions WHERE app_name = ? AND user_id = ? AND id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + stmt.setString(3, sessionId); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Session should exist in sessions table"); + + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + assertEquals(TEST_USER_ID, rs.getString("user_id")); + assertEquals(sessionId, rs.getString("id")); + + // Verify STRING state (stored as JSON string) + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertTrue(stateJson.contains("test_key")); + assertTrue(stateJson.contains("test_value")); + + // Verify timestamps + Timestamp createTime = rs.getTimestamp("create_time"); + Timestamp updateTime = rs.getTimestamp("update_time"); + assertNotNull(createTime, "create_time should not be null"); + assertNotNull(updateTime, "update_time should not be null"); + + assertFalse(rs.next(), "Should only have one session with this ID"); + } + } + } + } + + @Test + public void testSessionsTableUpdate() throws SQLException { + String sessionId = "session-update-" + System.currentTimeMillis(); + + // Create initial session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Timestamp initialUpdateTime; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT update_time FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + initialUpdateTime = rs.getTimestamp("update_time"); + } + } + } + + // Wait a bit to ensure timestamp difference + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Update session by appending event + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("updated_key", "updated_value"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Update event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify update_time changed + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT state, update_time FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + + String stateJson = rs.getString("state"); + assertTrue(stateJson.contains("updated_key")); + assertTrue(stateJson.contains("updated_value")); + + Timestamp newUpdateTime = rs.getTimestamp("update_time"); + assertTrue( + newUpdateTime.after(initialUpdateTime), + "update_time should be updated after state change"); + } + } + } + } + + @Test + public void testSessionsTablePrimaryKey() throws SQLException { + String sessionId = "session-pk-" + System.currentTimeMillis(); + + // Create session + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Verify primary key constraint (app_name, user_id, id) + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + // Count sessions with this ID + String query = "SELECT COUNT(*) FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Should have exactly one session with this ID"); + } + } + } + } + + // ==================== EVENTS TABLE TESTS ==================== + + @Test + public void testEventsTableCreation() throws SQLException { + String sessionId = "session-events-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Create event + String eventId = UUID.randomUUID().toString(); + String invocationId = "inv-" + UUID.randomUUID().toString(); + long timestamp = Instant.now().toEpochMilli(); + + Event event = + Event.builder() + .id(eventId) + .invocationId(invocationId) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event content"))) + .timestamp(timestamp) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify in database + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT id, app_name, user_id, session_id, invocation_id, timestamp, event_data " + + "FROM events WHERE id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, eventId); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Event should exist in events table"); + + assertEquals(eventId, rs.getString("id")); + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + assertEquals(TEST_USER_ID, rs.getString("user_id")); + assertEquals(sessionId, rs.getString("session_id")); + assertEquals(invocationId, rs.getString("invocation_id")); + + Timestamp eventTimestamp = rs.getTimestamp("timestamp"); + assertNotNull(eventTimestamp); + + // Verify STRING event_data (stored as JSON string) + String eventData = rs.getString("event_data"); + assertNotNull(eventData); + assertTrue(eventData.contains("Test event content")); + assertTrue(eventData.contains("test-author")); + + assertFalse(rs.next(), "Should only have one event with this ID"); + } + } + } + } + + @Test + public void testEventsTableMultipleEvents() throws SQLException { + String sessionId = "session-multi-events-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Create multiple events + int numEvents = 5; + for (int i = 0; i < numEvents; i++) { + Event event = + Event.builder() + .id("event-" + i + "-" + UUID.randomUUID()) + .author("author-" + i) + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + // Verify event count in database + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT COUNT(*) FROM events WHERE app_name = ? AND user_id = ? AND session_id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + stmt.setString(3, sessionId); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(numEvents, rs.getInt(1), "Should have " + numEvents + " events"); + } + } + } + } + + @Test + public void testEventsTableForeignKeyConstraint() throws SQLException { + String sessionId = "session-fk-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Add event + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("FK test event"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify events exist + int eventCountBefore; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM events WHERE session_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + eventCountBefore = rs.getInt(1); + assertTrue(eventCountBefore > 0, "Should have at least one event"); + } + } + } + + // Spanner doesn't support CASCADE DELETE - manually delete events first + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + // First delete events + String deleteEventsQuery = "DELETE FROM events WHERE session_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(deleteEventsQuery)) { + stmt.setString(1, sessionId); + stmt.executeUpdate(); + } + + // Then delete session + String deleteSessionQuery = + "DELETE FROM sessions WHERE app_name = ? AND user_id = ? AND id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(deleteSessionQuery)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + stmt.setString(3, sessionId); + stmt.executeUpdate(); + } + } + + // Verify events were deleted + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM events WHERE session_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(0, rs.getInt(1), "Events should be deleted with session"); + } + } + } + } + + @Test + public void testEventsTableTimestampOrdering() throws SQLException { + String sessionId = "session-timestamp-" + System.currentTimeMillis(); + + // Create session + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Create events with specific timestamps + long baseTime = Instant.now().toEpochMilli(); + for (int i = 0; i < 3; i++) { + Event event = + Event.builder() + .id("event-" + i + "-" + UUID.randomUUID()) + .author("author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(baseTime + (i * 1000)) + .build(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + // Verify events are ordered by timestamp + // Note: Spanner doesn't support JSON operators like PostgreSQL/MySQL + // We'll just verify timestamp ordering without extracting JSON fields + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT id, timestamp " + "FROM events WHERE session_id = ? ORDER BY timestamp ASC"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, sessionId); + + try (ResultSet rs = stmt.executeQuery()) { + Timestamp previousTimestamp = null; + int count = 0; + + while (rs.next()) { + Timestamp currentTimestamp = rs.getTimestamp("timestamp"); + assertNotNull(currentTimestamp); + + if (previousTimestamp != null) { + assertTrue( + currentTimestamp.after(previousTimestamp) + || currentTimestamp.equals(previousTimestamp), + "Events should be ordered by timestamp"); + } + + previousTimestamp = currentTimestamp; + count++; + } + + assertEquals(3, count, "Should have 3 events"); + } + } + } + } + + // ==================== APP_STATES TABLE TESTS ==================== + + @Test + public void testAppStatesTableCreation() throws SQLException { + String sessionId = "session-app-state-" + System.currentTimeMillis(); + + // Create session with app state + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("app:api_endpoint", "https://api.example.com"); + state.put("app:version", "2.0"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify in app_states table + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT app_name, state, update_time FROM app_states WHERE app_name = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "App state should exist in app_states table"); + + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertTrue(stateJson.contains("api_endpoint")); + assertTrue(stateJson.contains("https://api.example.com")); + assertTrue(stateJson.contains("version")); + assertTrue(stateJson.contains("2.0")); + + Timestamp updateTime = rs.getTimestamp("update_time"); + assertNotNull(updateTime, "update_time should not be null"); + } + } + } + } + + @Test + public void testAppStatesTableUpsert() throws SQLException { + String sessionId1 = "session-app-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "session-app-upsert-2-" + System.currentTimeMillis(); + + // Create first session with app state + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("app:config", "version1"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + Timestamp firstUpdateTime; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT update_time FROM app_states WHERE app_name = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + firstUpdateTime = rs.getTimestamp("update_time"); + } + } + } + + // Wait to ensure timestamp difference + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Create second session with updated app state + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("app:config", "version2"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + // Verify app state was updated (not duplicated) + // Spanner doesn't support GROUP BY on JSON (STRING) columns + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT state, update_time FROM app_states WHERE app_name = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Should have app state"); + + String stateJson = rs.getString("state"); + assertTrue(stateJson.contains("version2"), "State should be updated to version2"); + + Timestamp newUpdateTime = rs.getTimestamp("update_time"); + assertTrue( + newUpdateTime.after(firstUpdateTime) || newUpdateTime.equals(firstUpdateTime), + "update_time should be updated"); + + assertFalse(rs.next(), "Should only have one app state row"); + } + } + } + } + + @Test + public void testAppStatesTablePrimaryKey() throws SQLException { + String sessionId = "session-app-pk-" + System.currentTimeMillis(); + + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("app:key", "value"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify only one row per app_name + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM app_states WHERE app_name = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Should have exactly one row per app_name"); + } + } + } + } + + // ==================== USER_STATES TABLE TESTS ==================== + + @Test + public void testUserStatesTableCreation() throws SQLException { + String sessionId = "session-user-state-" + System.currentTimeMillis(); + + // Create session with user state + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("user:language", "English"); + state.put("user:timezone", "UTC"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify in user_states table + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT app_name, user_id, state, update_time FROM user_states WHERE app_name = ? AND user_id = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "User state should exist in user_states table"); + + assertEquals(TEST_APP_NAME, rs.getString("app_name")); + assertEquals(TEST_USER_ID, rs.getString("user_id")); + + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertTrue(stateJson.contains("language")); + assertTrue(stateJson.contains("English")); + assertTrue(stateJson.contains("timezone")); + assertTrue(stateJson.contains("UTC")); + + Timestamp updateTime = rs.getTimestamp("update_time"); + assertNotNull(updateTime, "update_time should not be null"); + } + } + } + } + + @Test + public void testUserStatesTableUpsert() throws SQLException { + String sessionId1 = "session-user-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "session-user-upsert-2-" + System.currentTimeMillis(); + + // Create first session with user state + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("user:preference", "dark"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + Timestamp firstUpdateTime; + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT update_time FROM user_states WHERE app_name = ? AND user_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + firstUpdateTime = rs.getTimestamp("update_time"); + } + } + } + + // Wait to ensure timestamp difference + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Create second session with updated user state + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("user:preference", "light"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + // Verify user state was updated (not duplicated) + // Spanner doesn't support GROUP BY on JSON (STRING) columns + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT state, update_time FROM user_states WHERE app_name = ? AND user_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "Should have user state"); + + String stateJson = rs.getString("state"); + assertTrue(stateJson.contains("light"), "State should be updated to 'light'"); + + Timestamp newUpdateTime = rs.getTimestamp("update_time"); + assertTrue( + newUpdateTime.after(firstUpdateTime) || newUpdateTime.equals(firstUpdateTime), + "update_time should be updated"); + + assertFalse(rs.next(), "Should only have one user state row"); + } + } + } + } + + @Test + public void testUserStatesTablePrimaryKey() throws SQLException { + String sessionId = "session-user-pk-" + System.currentTimeMillis(); + + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("user:key", "value"); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Verify only one row per (app_name, user_id) combination + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = "SELECT COUNT(*) FROM user_states WHERE app_name = ? AND user_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Should have exactly one row per (app_name, user_id)"); + } + } + } + } + + @Test + public void testUserStatesTableIsolationBetweenUsers() throws SQLException { + String user1 = "user-1-" + System.currentTimeMillis(); + String user2 = "user-2-" + System.currentTimeMillis(); + + // Create sessions for two different users + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("user:language", "French"); + sessionService.createSession(TEST_APP_NAME, user1, state1, "session-1").blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("user:language", "Spanish"); + sessionService.createSession(TEST_APP_NAME, user2, state2, "session-2").blockingGet(); + + // Verify both users have separate state + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + String query = + "SELECT user_id, state FROM user_states WHERE app_name = ? AND user_id IN (?, ?)"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, user1); + stmt.setString(3, user2); + try (ResultSet rs = stmt.executeQuery()) { + int count = 0; + while (rs.next()) { + String userId = rs.getString("user_id"); + String stateJson = rs.getString("state"); + + if (userId.equals(user1)) { + assertTrue(stateJson.contains("French"), "User 1 should have French"); + } else if (userId.equals(user2)) { + assertTrue(stateJson.contains("Spanish"), "User 2 should have Spanish"); + } + count++; + } + assertEquals(2, count, "Should have state for both users"); + } + } + } + } + + // ==================== CROSS-TABLE INTEGRATION TESTS ==================== + + @Test + public void testAllTablesIntegration() throws SQLException { + String sessionId = "session-integration-" + System.currentTimeMillis(); + + // Create session with all state types + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("session_key", "session_value"); // session state + state.put("app:api_key", "app-12345"); // app state + state.put("user:theme", "dark"); // user state + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Add event + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("integration-test") + .content(Content.fromParts(Part.fromText("Integration test event"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + // Verify all tables have data + try (Connection conn = DriverManager.getConnection(TEST_DB_URL)) { + // Check sessions table + String sessionQuery = "SELECT COUNT(*) FROM sessions WHERE id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(sessionQuery)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1), "Session should exist"); + } + } + + // Check events table + String eventQuery = "SELECT COUNT(*) FROM events WHERE session_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(eventQuery)) { + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertTrue(rs.getInt(1) > 0, "Should have at least one event"); + } + } + + // Check app_states table + String appQuery = "SELECT state FROM app_states WHERE app_name = ?"; + try (PreparedStatement stmt = conn.prepareStatement(appQuery)) { + stmt.setString(1, TEST_APP_NAME); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "App state should exist"); + String appState = rs.getString("state"); + assertTrue(appState.contains("api_key"), "App state should contain api_key"); + } + } + + // Check user_states table + String userQuery = "SELECT state FROM user_states WHERE app_name = ? AND user_id = ?"; + try (PreparedStatement stmt = conn.prepareStatement(userQuery)) { + stmt.setString(1, TEST_APP_NAME); + stmt.setString(2, TEST_USER_ID); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "User state should exist"); + String userState = rs.getString("state"); + assertTrue(userState.contains("theme"), "User state should contain theme"); + } + } + } + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateDeltaTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateDeltaTest.java new file mode 100644 index 000000000..dc4467b4b --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateDeltaTest.java @@ -0,0 +1,400 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class StateDeltaTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:jdbc_delta_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "delta-test-app"; + private static final String TEST_USER_ID = "delta-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testStateDeltaInEvent() { + String sessionId = "delta-event-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("counter", 1); + delta.put("new_field", "added"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertEquals(1, updated.state().get("counter")); + assertEquals("added", updated.state().get("new_field")); + } + + @Test + public void testAppStateDeltaInEvent() { + String sessionId = "app-delta-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "app_counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "app_counter", 10); + delta.put(State.APP_PREFIX + "app_field", "app_value"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertEquals(10, updated.state().get(State.APP_PREFIX + "app_counter")); + assertEquals("app_value", updated.state().get(State.APP_PREFIX + "app_field")); + } + + @Test + public void testUserStateDeltaInEvent() { + String sessionId = "user-delta-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.USER_PREFIX + "user_counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.USER_PREFIX + "user_counter", 5); + delta.put(State.USER_PREFIX + "user_field", "user_value"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertEquals(5, updated.state().get(State.USER_PREFIX + "user_counter")); + assertEquals("user_value", updated.state().get(State.USER_PREFIX + "user_field")); + } + + @Test + public void testMixedStateDeltaInEvent() { + String sessionId = "mixed-delta-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "app_value", "initial_app"); + initialState.put(State.USER_PREFIX + "user_value", "initial_user"); + initialState.put("session_value", "initial_session"); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "app_value", "updated_app"); + delta.put(State.USER_PREFIX + "user_value", "updated_user"); + delta.put("session_value", "updated_session"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertEquals("updated_app", updated.state().get(State.APP_PREFIX + "app_value")); + assertEquals("updated_user", updated.state().get(State.USER_PREFIX + "user_value")); + assertEquals("updated_session", updated.state().get("session_value")); + } + + @Test + public void testStateRemovalViaRemoved() { + String sessionId = "session-removal-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("key_to_remove", "value"); + initialState.put("key_to_keep", "keep_this"); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("key_to_remove", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertFalse(updated.state().containsKey("key_to_remove")); + assertEquals("keep_this", updated.state().get("key_to_keep")); + } + + @Test + public void testAppStateRemovalViaRemoved() { + String sessionId = "app-removal-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "deprecated", "old_value"); + initialState.put(State.APP_PREFIX + "current", "keep_this"); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "deprecated", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertFalse(updated.state().containsKey(State.APP_PREFIX + "deprecated")); + assertEquals("keep_this", updated.state().get(State.APP_PREFIX + "current")); + } + + @Test + public void testUserStateRemovalViaRemoved() { + String sessionId = "user-removal-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.USER_PREFIX + "old_pref", "remove_me"); + initialState.put(State.USER_PREFIX + "new_pref", "keep_this"); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.USER_PREFIX + "old_pref", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertFalse(updated.state().containsKey(State.USER_PREFIX + "old_pref")); + assertEquals("keep_this", updated.state().get(State.USER_PREFIX + "new_pref")); + } + + @Test + public void testMixedStateRemovalViaRemoved() { + String sessionId = "mixed-removal-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "app_deprecated", "remove"); + initialState.put(State.APP_PREFIX + "app_current", "keep"); + initialState.put(State.USER_PREFIX + "user_old", "remove"); + initialState.put(State.USER_PREFIX + "user_new", "keep"); + initialState.put("session_temp", "remove"); + initialState.put("session_data", "keep"); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "app_deprecated", State.REMOVED); + delta.put(State.USER_PREFIX + "user_old", State.REMOVED); + delta.put("session_temp", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertFalse(updated.state().containsKey(State.APP_PREFIX + "app_deprecated")); + assertFalse(updated.state().containsKey(State.USER_PREFIX + "user_old")); + assertFalse(updated.state().containsKey("session_temp")); + assertEquals("keep", updated.state().get(State.APP_PREFIX + "app_current")); + assertEquals("keep", updated.state().get(State.USER_PREFIX + "user_new")); + assertEquals("keep", updated.state().get("session_data")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateManagementTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateManagementTest.java new file mode 100644 index 000000000..66301903c --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateManagementTest.java @@ -0,0 +1,392 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class StateManagementTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:state_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private static final String TEST_APP_NAME = "state-test-app"; + private static final String TEST_USER_ID_1 = "user-1"; + private static final String TEST_USER_ID_2 = "user-2"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testAppStateSharing() { + String sessionId1 = "session-1"; + String sessionId2 = "session-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.APP_PREFIX + "global_setting", "shared_value"); + state1.put("local", "private_value_1"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.APP_PREFIX + "global_setting", "updated_value"); + state2.put("local", "private_value_2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_2, state2, sessionId2).blockingGet(); + + Session retrieved1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session retrieved2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_2, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved1); + assertNotNull(retrieved2); + + assertEquals("updated_value", retrieved1.state().get(State.APP_PREFIX + "global_setting")); + assertEquals("updated_value", retrieved2.state().get(State.APP_PREFIX + "global_setting")); + + assertEquals("private_value_1", retrieved1.state().get("local")); + assertEquals("private_value_2", retrieved2.state().get("local")); + } + + @Test + public void testUserStateSharing() { + String sessionId1 = "user-session-1"; + String sessionId2 = "user-session-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.USER_PREFIX + "preference", "dark_mode"); + state1.put("data", "session_specific_1"); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.USER_PREFIX + "preference", "light_mode"); + state2.put("data", "session_specific_2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(session1); + assertNotNull(session2); + + assertEquals("light_mode", session1.state().get(State.USER_PREFIX + "preference")); + assertEquals("light_mode", session2.state().get(State.USER_PREFIX + "preference")); + + assertEquals("session_specific_1", session1.state().get("data")); + assertEquals("session_specific_2", session2.state().get("data")); + } + + @Test + public void testSessionStateIsolation() { + String sessionId1 = "isolated-1"; + String sessionId2 = "isolated-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("private_key", "value_1"); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("private_key", "value_2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(session1); + assertNotNull(session2); + + assertEquals("value_1", session1.state().get("private_key")); + assertEquals("value_2", session2.state().get("private_key")); + } + + @Test + public void testStatePriorityMerging() { + String sessionId = "priority-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("_app_key", "app_value"); + initialState.put("_user_key", "user_value"); + initialState.put("key", "session_value"); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, initialState, sessionId) + .blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("app_value", retrieved.state().get("_app_key")); + assertEquals("user_value", retrieved.state().get("_user_key")); + assertEquals("session_value", retrieved.state().get("key")); + } + + @Test + public void testTempStateIsIgnored() { + String sessionId = "temp-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("temp:ignored", "should_not_persist"); + initialState.put("persisted", "should_persist"); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, initialState, sessionId) + .blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("should_persist", retrieved.state().get("persisted")); + assertEquals(null, retrieved.state().get("temp:ignored")); + } + + @Test + public void testStateMerge_putAllDoesNotLoseData() { + String sessionId1 = "merge-test-1"; + String sessionId2 = "merge-test-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.APP_PREFIX + "key1", "value1"); + state1.put(State.APP_PREFIX + "key2", "value2"); + state1.put(State.APP_PREFIX + "key3", "value3"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.APP_PREFIX + "key4", "value4"); + state2.put(State.APP_PREFIX + "key5", "value5"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("value1", retrieved.state().get(State.APP_PREFIX + "key1")); + assertEquals("value2", retrieved.state().get(State.APP_PREFIX + "key2")); + assertEquals("value3", retrieved.state().get(State.APP_PREFIX + "key3")); + assertEquals("value4", retrieved.state().get(State.APP_PREFIX + "key4")); + assertEquals("value5", retrieved.state().get(State.APP_PREFIX + "key5")); + } + + @Test + public void testStateMerge_nestedObjectsPreserved() { + String sessionId1 = "nested-merge-1"; + String sessionId2 = "nested-merge-2"; + + ConcurrentHashMap nestedMap1 = new ConcurrentHashMap<>(); + nestedMap1.put("nested_key_1", "nested_value_1"); + nestedMap1.put("nested_key_2", 42); + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.APP_PREFIX + "config", nestedMap1); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap nestedMap2 = new ConcurrentHashMap<>(); + nestedMap2.put("another_nested_key", "another_value"); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.APP_PREFIX + "other_config", nestedMap2); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertNotNull(retrieved.state().get(State.APP_PREFIX + "config")); + assertNotNull(retrieved.state().get(State.APP_PREFIX + "other_config")); + } + + @Test + public void testStateMerge_overwriteExistingKeys() { + String sessionId1 = "overwrite-1"; + String sessionId2 = "overwrite-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.APP_PREFIX + "shared_key", "original_value"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.APP_PREFIX + "shared_key", "updated_value"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("updated_value", retrieved.state().get(State.APP_PREFIX + "shared_key")); + } + + @Test + public void testStateMerge_userStateDoesNotLoseData() { + String sessionId1 = "user-merge-1"; + String sessionId2 = "user-merge-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.USER_PREFIX + "pref1", "value1"); + state1.put(State.USER_PREFIX + "pref2", "value2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.USER_PREFIX + "pref3", "value3"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("value1", retrieved.state().get(State.USER_PREFIX + "pref1")); + assertEquals("value2", retrieved.state().get(State.USER_PREFIX + "pref2")); + assertEquals("value3", retrieved.state().get(State.USER_PREFIX + "pref3")); + } + + @Test + public void testStateMerge_sessionStateRemainsIsolated() { + String sessionId1 = "session-isolated-1"; + String sessionId2 = "session-isolated-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("session_key1", "session_value1"); + state1.put(State.APP_PREFIX + "app_key", "shared"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("session_key2", "session_value2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session retrieved2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved1); + assertNotNull(retrieved2); + + assertEquals("session_value1", retrieved1.state().get("session_key1")); + assertEquals(null, retrieved1.state().get("session_key2")); + + assertEquals(null, retrieved2.state().get("session_key1")); + assertEquals("session_value2", retrieved2.state().get("session_key2")); + + assertEquals("shared", retrieved1.state().get(State.APP_PREFIX + "app_key")); + assertEquals("shared", retrieved2.state().get(State.APP_PREFIX + "app_key")); + } + + @Test + public void testStateMerge_largeStateDoesNotLoseData() { + String sessionId1 = "large-state-1"; + String sessionId2 = "large-state-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + for (int i = 0; i < 50; i++) { + state1.put(State.APP_PREFIX + "key_" + i, "value_" + i); + } + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + for (int i = 50; i < 100; i++) { + state2.put(State.APP_PREFIX + "key_" + i, "value_" + i); + } + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + + for (int i = 0; i < 100; i++) { + assertEquals( + "value_" + i, + retrieved.state().get(State.APP_PREFIX + "key_" + i), + "Key " + i + " should not be lost during merge"); + } + } + + @Test + public void testStateMerge_roundTripSerialization() { + String sessionId = "roundtrip-test"; + + ConcurrentHashMap originalState = new ConcurrentHashMap<>(); + originalState.put(State.APP_PREFIX + "string_key", "string_value"); + originalState.put(State.APP_PREFIX + "int_key", 42); + originalState.put(State.APP_PREFIX + "double_key", 3.14); + originalState.put(State.APP_PREFIX + "boolean_key", true); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, originalState, sessionId) + .blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("string_value", retrieved.state().get(State.APP_PREFIX + "string_key")); + assertEquals(42, retrieved.state().get(State.APP_PREFIX + "int_key")); + assertEquals(3.14, retrieved.state().get(State.APP_PREFIX + "double_key")); + assertEquals(true, retrieved.state().get(State.APP_PREFIX + "boolean_key")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/StatePrefixHandlingTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StatePrefixHandlingTest.java new file mode 100644 index 000000000..5dcb1a2ff --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StatePrefixHandlingTest.java @@ -0,0 +1,334 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.*; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Integration test to verify that app/user state is stored WITHOUT prefixes in the database but + * retrieved WITH prefixes in the session state. + * + *

This ensures compatibility with: + * + *

    + *
  • Python DatabaseSessionService implementation + *
  • Java InMemorySessionService implementation + *
  • Proper state isolation and namespace handling + *
+ */ +public class StatePrefixHandlingTest { + + private static final String TEST_DB_URL = "jdbc:h2:mem:testdb_prefix;DB_CLOSE_DELAY=-1"; + private static final String TEST_APP_NAME = "test-app"; + private static final String TEST_USER_ID_1 = "user1"; + private static final String TEST_USER_ID_2 = "user2"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement()) { + stmt.execute("DELETE FROM events"); + stmt.execute("DELETE FROM sessions"); + stmt.execute("DELETE FROM app_states"); + stmt.execute("DELETE FROM user_states"); + } catch (Exception e) { + } + sessionService.close(); + } + } + + @Test + public void testAppStatePrefixStrippedInDatabase() throws Exception { + String sessionId = "app-prefix-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "counter", 42); + delta.put(State.APP_PREFIX + "theme", "dark"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement(); + ResultSet rs = + stmt.executeQuery( + "SELECT state FROM app_states WHERE app_name = '" + TEST_APP_NAME + "'")) { + + assertTrue(rs.next(), "App state should exist in database"); + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertFalse( + stateJson.contains("\"app:counter\""), "Database should NOT contain 'app:' prefix"); + assertFalse(stateJson.contains("\"app:theme\""), "Database should NOT contain 'app:' prefix"); + assertTrue(stateJson.contains("\"counter\""), "Database should contain unprefixed 'counter'"); + assertTrue(stateJson.contains("\"theme\""), "Database should contain unprefixed 'theme'"); + } + } + + @Test + public void testUserStatePrefixStrippedInDatabase() throws Exception { + String sessionId = "user-prefix-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.USER_PREFIX + "preference", "enabled"); + delta.put(State.USER_PREFIX + "language", "en"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement(); + ResultSet rs = + stmt.executeQuery( + "SELECT state FROM user_states WHERE app_name = '" + + TEST_APP_NAME + + "' AND user_id = '" + + TEST_USER_ID_1 + + "'")) { + + assertTrue(rs.next(), "User state should exist in database"); + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertFalse( + stateJson.contains("\"user:preference\""), "Database should NOT contain 'user:' prefix"); + assertFalse( + stateJson.contains("\"user:language\""), "Database should NOT contain 'user:' prefix"); + assertTrue( + stateJson.contains("\"preference\""), "Database should contain unprefixed 'preference'"); + assertTrue( + stateJson.contains("\"language\""), "Database should contain unprefixed 'language'"); + } + } + + @Test + public void testSessionStatePrefixAddedDuringRetrieval() { + String sessionId = "retrieval-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "global_value", "shared"); + delta.put(State.USER_PREFIX + "user_value", "personal"); + delta.put("session_value", "private"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals( + "shared", + retrieved.state().get(State.APP_PREFIX + "global_value"), + "App state should have 'app:' prefix"); + assertEquals( + "personal", + retrieved.state().get(State.USER_PREFIX + "user_value"), + "User state should have 'user:' prefix"); + assertEquals( + "private", retrieved.state().get("session_value"), "Session state should NOT have prefix"); + } + + @Test + public void testAppStateSharedAcrossUsers() { + String sessionId1 = "session1"; + String sessionId2 = "session2"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId1) + .blockingGet(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_2, new ConcurrentHashMap<>(), sessionId2) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "feature_flag", true); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_2, sessionId2, Optional.empty()) + .blockingGet(); + + assertTrue( + (Boolean) session2.state().get(State.APP_PREFIX + "feature_flag"), + "User 2 should see app state set by User 1"); + } + + @Test + public void testUserStateIsolatedBetweenUsers() { + String sessionId1 = "session1"; + String sessionId2 = "session2"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId1) + .blockingGet(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_2, new ConcurrentHashMap<>(), sessionId2) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.USER_PREFIX + "timezone", "UTC"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_2, sessionId2, Optional.empty()) + .blockingGet(); + + assertNull( + session2.state().get(State.USER_PREFIX + "timezone"), + "User 2 should NOT see user state set by User 1"); + } + + @Test + public void testUserStateSharedAcrossSessionsForSameUser() { + String sessionId1 = "session1"; + String sessionId2 = "session2"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId1) + .blockingGet(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId2) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.USER_PREFIX + "notification_enabled", true); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertTrue( + (Boolean) session2.state().get(State.USER_PREFIX + "notification_enabled"), + "Same user should see user state across different sessions"); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/testing/TestDatabaseConfig.java b/contrib/database-session-service/src/test/java/com/google/adk/testing/TestDatabaseConfig.java new file mode 100644 index 000000000..b9b113ade --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/testing/TestDatabaseConfig.java @@ -0,0 +1,125 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.testing; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; + +/** + * Centralized configuration for integration test databases. + * + *

These connection strings assume test databases are running via + * scripts/docker-compose.test.yml: + * + *

{@code
+ * docker-compose -f scripts/docker-compose.test.yml up -d
+ * }
+ */ +public final class TestDatabaseConfig { + + private TestDatabaseConfig() {} + + // MySQL Test Database Configuration + public static final String MYSQL_HOST = "localhost"; + public static final int MYSQL_PORT = 3306; + public static final String MYSQL_DATABASE = "adk_test"; + public static final String MYSQL_USER = "adk_user"; + public static final String MYSQL_PASSWORD = "adk_password"; + public static final String MYSQL_JDBC_URL = + String.format( + "jdbc:mysql://%s:%d/%s?user=%s&password=%s&useSSL=false&allowPublicKeyRetrieval=true", + MYSQL_HOST, MYSQL_PORT, MYSQL_DATABASE, MYSQL_USER, MYSQL_PASSWORD); + + // PostgreSQL Test Database Configuration + public static final String POSTGRES_HOST = "localhost"; + public static final int POSTGRES_PORT = 5432; + public static final String POSTGRES_DATABASE = "adk_test"; + public static final String POSTGRES_USER = "adk_user"; + public static final String POSTGRES_PASSWORD = "adk_password"; + public static final String POSTGRES_JDBC_URL = + String.format( + "jdbc:postgresql://%s:%d/%s?user=%s&password=%s", + POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DATABASE, POSTGRES_USER, POSTGRES_PASSWORD); + + // Cloud Spanner Emulator Configuration + public static final String SPANNER_HOST = "localhost"; + public static final int SPANNER_PORT = 9010; + public static final String SPANNER_PROJECT = "test-project"; + public static final String SPANNER_INSTANCE = "test-instance"; + public static final String SPANNER_DATABASE = "test-db"; + public static final String SPANNER_JDBC_URL = + String.format( + "jdbc:cloudspanner://%s:%d/projects/%s/instances/%s/databases/%s?autoConfigEmulator=true", + SPANNER_HOST, SPANNER_PORT, SPANNER_PROJECT, SPANNER_INSTANCE, SPANNER_DATABASE); + + /** + * Checks if MySQL test database is available. + * + * @return true if connection succeeds, false otherwise + */ + public static boolean isMySQLAvailable() { + try (Connection conn = DriverManager.getConnection(MYSQL_JDBC_URL)) { + return conn.isValid(2); + } catch (SQLException e) { + return false; + } + } + + /** + * Checks if PostgreSQL test database is available. + * + * @return true if connection succeeds, false otherwise + */ + public static boolean isPostgreSQLAvailable() { + try (Connection conn = DriverManager.getConnection(POSTGRES_JDBC_URL)) { + return conn.isValid(2); + } catch (SQLException e) { + return false; + } + } + + /** + * Checks if Cloud Spanner emulator is available. + * + * @return true if connection succeeds, false otherwise + */ + public static boolean isSpannerAvailable() { + try (Connection conn = DriverManager.getConnection(SPANNER_JDBC_URL)) { + return conn.isValid(2); + } catch (SQLException e) { + return false; + } + } + + /** + * Returns a helpful message for skipped tests when database is not available. + * + * @param databaseName The name of the database (MySQL, PostgreSQL, or Spanner) + * @return A message explaining how to start the database + */ + public static String getDatabaseNotAvailableMessage(String databaseName) { + if ("Spanner".equalsIgnoreCase(databaseName)) { + return "Cloud Spanner emulator not available. Start it with: " + + "docker run -d -p 9010:9010 -p 9020:9020 gcr.io/cloud-spanner-emulator/emulator && " + + "export SPANNER_EMULATOR_HOST=localhost:9010"; + } + return String.format( + "%s test database not available. Start it with: " + + "docker-compose -f docker-compose.test.yml up -d %s-test", + databaseName, databaseName.toLowerCase()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/testing/TestLlm.java b/contrib/database-session-service/src/test/java/com/google/adk/testing/TestLlm.java new file mode 100644 index 000000000..aaacf00a0 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/testing/TestLlm.java @@ -0,0 +1,310 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.testing; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.adk.agents.LiveRequest; +import com.google.adk.models.BaseLlm; +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; +import java.util.function.Supplier; +import javax.annotation.Nullable; + +/** + * A test implementation of {@link BaseLlm}. + * + *

Supports providing responses via a sequence of {@link LlmResponse} objects or a {@link + * Supplier} of {@code Flowable}. It also captures all standard and live requests for + * assertion in tests. + */ +public final class TestLlm extends BaseLlm { + private final List llmRequests = Collections.synchronizedList(new ArrayList<>()); + private final List liveRequestHistory = + Collections.synchronizedList(new ArrayList<>()); + + private final List responseSequence; + private final AtomicInteger responseIndex = new AtomicInteger(0); + + private final Supplier> responsesSupplier; + private final Optional error; + + private TestLlm( + @Nullable List responses, + @Nullable Supplier> responsesSupplier, + @Nullable Throwable error) { + super("test-llm"); + this.responseSequence = responses; + this.responsesSupplier = responsesSupplier; + this.error = Optional.ofNullable(error); + } + + /** + * Constructs a TestLlm that serves responses sequentially from the provided list. + * + * @param responses A list of LlmResponse objects to be served in order. Can be null or empty. + */ + public TestLlm(@Nullable List responses) { + this(responses == null ? ImmutableList.of() : ImmutableList.copyOf(responses), null, null); + } + + /** + * Constructs a TestLlm that uses the provided supplier to get responses. + * + * @param responsesSupplier A supplier that provides a Flowable of LlmResponse. + */ + public TestLlm(Supplier> responsesSupplier) { + this(null, responsesSupplier, null); + } + + @CanIgnoreReturnValue + public static TestLlm create(@Nullable List responses, @Nullable Throwable error) { + if (error != null) { + return new TestLlm(ImmutableList.of(), null, error); + } + if (responses == null || responses.isEmpty()) { + return new TestLlm(ImmutableList.of(), null, null); + } + + List llmResponses = new ArrayList<>(); + Object first = responses.get(0); + if (first instanceof LlmResponse) { + // responses is List + for (Object response : responses) { + if (response instanceof LlmResponse llmResponse) { + llmResponses.add(llmResponse); + } else { + throw new IllegalArgumentException("Mixed response types in List"); + } + } + } else if (first instanceof String) { + // responses is List + for (Object item : responses) { + if (item instanceof String string) { + llmResponses.add( + LlmResponse.builder() + .content(Content.builder().parts(ImmutableList.of(Part.fromText(string))).build()) + .build()); + } else { + throw new IllegalArgumentException("Mixed response types in List"); + } + } + } else if (first instanceof Part) { + // responses is List + for (Object item : responses) { + if (item instanceof Part part) { + llmResponses.add( + LlmResponse.builder() + .content(Content.builder().parts(ImmutableList.of(part)).build()) + .build()); + } else { + throw new IllegalArgumentException("Mixed response types in List"); + } + } + } else if (first instanceof List) { + // responses is List> + for (Object item : responses) { + if (item instanceof List) { + List partList = (List) item; + if (!partList.isEmpty() && partList.get(0) instanceof Part) { + llmResponses.add( + LlmResponse.builder() + .content( + Content.builder() + .parts(partList.stream().map(p -> (Part) p).collect(toImmutableList())) + .build()) + .build()); + } else { + throw new IllegalArgumentException("Inner list elements are not Part instances."); + } + } else { + throw new IllegalArgumentException("Mixed response types in List"); + } + } + } else { + throw new IllegalArgumentException("Unsupported response type in List" + first.getClass()); + } + return new TestLlm(llmResponses, null, null); + } + + @CanIgnoreReturnValue + public static TestLlm create(@Nullable List responses) { + return create(responses, null); + } + + @CanIgnoreReturnValue + public static TestLlm create(String... responses) { + return create(Arrays.asList(responses), null); + } + + @CanIgnoreReturnValue + public static TestLlm create(LlmResponse... responses) { + return create(Arrays.asList(responses), null); + } + + @CanIgnoreReturnValue + public static TestLlm create(Part... responses) { + return create(Arrays.asList(responses), null); + } + + @Override + public Flowable generateContent(LlmRequest llmRequest, boolean stream) { + llmRequests.add(llmRequest); + + if (error.isPresent()) { + return Flowable.error(error.get()); + } + + if (this.responseSequence != null) { + // Sequential discrete response mode + int currentIndex = responseIndex.getAndIncrement(); + if (currentIndex < responseSequence.size()) { + LlmResponse nextResponse = responseSequence.get(currentIndex); + return Flowable.just(nextResponse); + } else { + return Flowable.error( + new NoSuchElementException( + "TestLlm (List mode) out of responses. Requested response for LLM call " + + llmRequests.size() + + " (index " + + currentIndex + + ") but only " + + responseSequence.size() + + " were configured.")); + } + } else if (this.responsesSupplier != null) { + // Legacy/streaming supplier mode + return responsesSupplier.get(); + } else { + // Should not happen if constructors are used properly + return Flowable.error(new IllegalStateException("TestLlm not initialized with responses.")); + } + } + + @Override + public BaseLlmConnection connect(LlmRequest llmRequest) { + llmRequests.add(llmRequest); + return new TestLlmConnection(); + } + + public ImmutableList getRequests() { + return ImmutableList.copyOf(llmRequests); + } + + public LlmRequest getLastRequest() { + return Iterables.getLast(llmRequests); + } + + /** Returns an immutable list of all {@link LiveRequest}s sent to the live connection. */ + public ImmutableList getLiveRequestHistory() { + return ImmutableList.copyOf(liveRequestHistory); + } + + public boolean waitForStreamingToolResults(String toolName, int expectedCount, Duration timeout) { + Instant deadline = Instant.now().plus(timeout); + String prefix = "Function " + toolName + " returned:"; + + Predicate isStreamingToolResult = + req -> + req.content() + .filter( + content -> + content.role().orElse("").equals("user") + && content.text() != null + && content.text().startsWith(prefix)) + .isPresent(); + + long currentCount = 0; + while (Instant.now().isBefore(deadline)) { + currentCount = getLiveRequestHistory().stream().filter(isStreamingToolResult).count(); + if (currentCount >= expectedCount) { + return true; + } + try { + Thread.sleep(200); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return false; + } + } + return false; + } + + /** A test implementation of {@link BaseLlmConnection} for {@link TestLlm}. */ + private final class TestLlmConnection implements BaseLlmConnection { + + @Override + public Completable sendHistory(List history) { + return Completable.complete(); + } + + @Override + public Completable sendContent(Content content) { + liveRequestHistory.add(LiveRequest.builder().content(content).build()); + return Completable.complete(); + } + + @Override + public Completable sendRealtime(Blob blob) { + liveRequestHistory.add(LiveRequest.builder().blob(blob).build()); + return Completable.complete(); + } + + @Override + public Flowable receive() { + if (error.isPresent()) { + return Flowable.error(error.get()); + } + if (responseSequence != null) { + return Flowable.fromIterable(responseSequence); + } else if (responsesSupplier != null) { + return responsesSupplier.get(); + } else { + return Flowable.error(new IllegalStateException("TestLlm not initialized with responses.")); + } + } + + @Override + public void close() { + liveRequestHistory.add(LiveRequest.builder().close(true).build()); + } + + @Override + public void close(Throwable throwable) { + close(); + } + } +} diff --git a/pom.xml b/pom.xml index 6a1aa5af5..250328b88 100644 --- a/pom.xml +++ b/pom.xml @@ -34,6 +34,7 @@ contrib/spring-ai contrib/samples contrib/firestore-session-service + contrib/database-session-service tutorials/city-time-weather tutorials/live-audio-single-agent a2a @@ -72,6 +73,9 @@ 1.4.0 3.9.0 5.4.3 + 1.2.0 + 6.2.1 + 11.17.0 @@ -244,6 +248,31 @@ graphviz-java ${graphviz.version} + + com.zaxxer + HikariCP + ${hikaricp.version} + + + org.flywaydb + flyway-core + ${flyway.version} + + + org.flywaydb + flyway-database-postgresql + ${flyway.version} + + + org.flywaydb + flyway-mysql + ${flyway.version} + + + org.flywaydb + flyway-gcp-spanner + ${flyway.version} + org.eclipse.jdt ecj @@ -558,4 +587,4 @@ https://central.sonatype.com/repository/maven-snapshots/ - \ No newline at end of file +