From 41a0074e70abf0da135d704d2fffeca5e1bee0ff Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 11:20:52 +0200 Subject: [PATCH 01/13] feat: add Model Serving connector and plugin Add the core Model Serving plugin that provides an authenticated proxy to Databricks Model Serving endpoints. Includes the connector layer (SDK client wrapper) and the plugin layer (Express routes for invoke/stream). Also adds UPSTREAM_ERROR SSE error code for propagating API errors. Signed-off-by: Pawel Kosiec --- .../api/appkit/Interface.EndpointConfig.md | 21 ++ .../appkit/Interface.ServingEndpointEntry.md | 27 ++ .../Interface.ServingEndpointRegistry.md | 5 + .../api/appkit/TypeAlias.ServingFactory.md | 19 + docs/docs/api/appkit/index.md | 4 + docs/docs/api/appkit/typedoc-sidebar.ts | 20 ++ docs/static/appkit-ui/styles.gen.css | 28 +- .../appkit/src/connectors/serving/client.ts | 223 ++++++++++++ .../connectors/serving/tests/client.test.ts | 303 ++++++++++++++++ .../appkit/src/connectors/serving/types.ts | 4 + packages/appkit/src/index.ts | 8 +- packages/appkit/src/plugins/index.ts | 1 + .../appkit/src/plugins/serving/defaults.ts | 26 ++ packages/appkit/src/plugins/serving/index.ts | 2 + .../appkit/src/plugins/serving/manifest.json | 54 +++ .../src/plugins/serving/schema-filter.ts | 127 +++++++ .../appkit/src/plugins/serving/serving.ts | 303 ++++++++++++++++ .../serving/tests/schema-filter.test.ts | 141 ++++++++ .../src/plugins/serving/tests/serving.test.ts | 339 ++++++++++++++++++ packages/appkit/src/plugins/serving/types.ts | 67 ++++ packages/appkit/src/stream/stream-manager.ts | 8 + packages/appkit/src/stream/types.ts | 1 + 22 files changed, 1727 insertions(+), 4 deletions(-) create mode 100644 docs/docs/api/appkit/Interface.EndpointConfig.md create mode 100644 docs/docs/api/appkit/Interface.ServingEndpointEntry.md create mode 100644 docs/docs/api/appkit/Interface.ServingEndpointRegistry.md create mode 100644 docs/docs/api/appkit/TypeAlias.ServingFactory.md create mode 100644 packages/appkit/src/connectors/serving/client.ts create mode 100644 packages/appkit/src/connectors/serving/tests/client.test.ts create mode 100644 packages/appkit/src/connectors/serving/types.ts create mode 100644 packages/appkit/src/plugins/serving/defaults.ts create mode 100644 packages/appkit/src/plugins/serving/index.ts create mode 100644 packages/appkit/src/plugins/serving/manifest.json create mode 100644 packages/appkit/src/plugins/serving/schema-filter.ts create mode 100644 packages/appkit/src/plugins/serving/serving.ts create mode 100644 packages/appkit/src/plugins/serving/tests/schema-filter.test.ts create mode 100644 packages/appkit/src/plugins/serving/tests/serving.test.ts create mode 100644 packages/appkit/src/plugins/serving/types.ts diff --git a/docs/docs/api/appkit/Interface.EndpointConfig.md b/docs/docs/api/appkit/Interface.EndpointConfig.md new file mode 100644 index 00000000..6ee94aa3 --- /dev/null +++ b/docs/docs/api/appkit/Interface.EndpointConfig.md @@ -0,0 +1,21 @@ +# Interface: EndpointConfig + +## Properties + +### env + +```ts +env: string; +``` + +Environment variable holding the endpoint name. + +*** + +### servedModel? + +```ts +optional servedModel: string; +``` + +Target a specific served model (bypasses traffic routing). diff --git a/docs/docs/api/appkit/Interface.ServingEndpointEntry.md b/docs/docs/api/appkit/Interface.ServingEndpointEntry.md new file mode 100644 index 00000000..fa054c3f --- /dev/null +++ b/docs/docs/api/appkit/Interface.ServingEndpointEntry.md @@ -0,0 +1,27 @@ +# Interface: ServingEndpointEntry + +Shape of a single registry entry. + +## Properties + +### chunk + +```ts +chunk: unknown; +``` + +*** + +### request + +```ts +request: Record; +``` + +*** + +### response + +```ts +response: unknown; +``` diff --git a/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md b/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md new file mode 100644 index 00000000..defe5270 --- /dev/null +++ b/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md @@ -0,0 +1,5 @@ +# Interface: ServingEndpointRegistry + +Registry interface for serving endpoint type generation. +Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. +When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. diff --git a/docs/docs/api/appkit/TypeAlias.ServingFactory.md b/docs/docs/api/appkit/TypeAlias.ServingFactory.md new file mode 100644 index 00000000..9ccafef5 --- /dev/null +++ b/docs/docs/api/appkit/TypeAlias.ServingFactory.md @@ -0,0 +1,19 @@ +# Type Alias: ServingFactory + +```ts +type ServingFactory = keyof ServingEndpointRegistry extends never ? (alias?: string) => ServingEndpointMethods : (alias: K) => ServingEndpointMethods; +``` + +Factory function returned by `AppKit.serving`. + +This is a conditional type that adapts based on whether `ServingEndpointRegistry` +has been populated via module augmentation (generated by `appKitServingTypesPlugin()`): + +- **Registry empty (default):** `(alias?: string) => ServingEndpointMethods` — + accepts any alias string with untyped request/response/chunk. +- **Registry populated:** `(alias: K) => ServingEndpointMethods<...>` — + restricts `alias` to known endpoint keys and infers typed request/response/chunk + from the registry entry. + +Run `appKitServingTypesPlugin()` in your Vite config to generate the registry +augmentation and enable full type safety. diff --git a/docs/docs/api/appkit/index.md b/docs/docs/api/appkit/index.md index b5fb7ce0..f4685e04 100644 --- a/docs/docs/api/appkit/index.md +++ b/docs/docs/api/appkit/index.md @@ -33,6 +33,7 @@ plugin architecture, and React integration. | [BasePluginConfig](Interface.BasePluginConfig.md) | Base configuration interface for AppKit plugins | | [CacheConfig](Interface.CacheConfig.md) | Configuration for the CacheInterceptor. Controls TTL, size limits, storage backend, and probabilistic cleanup. | | [DatabaseCredential](Interface.DatabaseCredential.md) | Database credentials with OAuth token for Postgres connection | +| [EndpointConfig](Interface.EndpointConfig.md) | - | | [GenerateDatabaseCredentialRequest](Interface.GenerateDatabaseCredentialRequest.md) | Request parameters for generating database OAuth credentials | | [ITelemetry](Interface.ITelemetry.md) | Plugin-facing interface for OpenTelemetry instrumentation. Provides a thin abstraction over OpenTelemetry APIs for plugins. | | [LakebasePoolConfig](Interface.LakebasePoolConfig.md) | Configuration for creating a Lakebase connection pool | @@ -42,6 +43,8 @@ plugin architecture, and React integration. | [ResourceEntry](Interface.ResourceEntry.md) | Internal representation of a resource in the registry. Extends ResourceRequirement with resolution state and plugin ownership. | | [ResourceFieldEntry](Interface.ResourceFieldEntry.md) | Defines a single field for a resource. Each field has its own environment variable and optional description. Single-value types use one key (e.g. id); multi-value types (database, secret) use multiple (e.g. instance_name, database_name or scope, key). | | [ResourceRequirement](Interface.ResourceRequirement.md) | Declares a resource requirement for a plugin. Can be defined statically in a manifest or dynamically via getResourceRequirements(). Narrows the generated base: type → ResourceType enum, permission → ResourcePermission union. | +| [ServingEndpointEntry](Interface.ServingEndpointEntry.md) | Shape of a single registry entry. | +| [ServingEndpointRegistry](Interface.ServingEndpointRegistry.md) | Registry interface for serving endpoint type generation. Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. | | [StreamExecutionSettings](Interface.StreamExecutionSettings.md) | Execution settings for streaming endpoints. Extends PluginExecutionSettings with SSE stream configuration. | | [TelemetryConfig](Interface.TelemetryConfig.md) | OpenTelemetry configuration for AppKit applications | | [ValidationResult](Interface.ValidationResult.md) | Result of validating all registered resources against the environment. | @@ -54,6 +57,7 @@ plugin architecture, and React integration. | [IAppRouter](TypeAlias.IAppRouter.md) | Express router type for plugin route registration | | [PluginData](TypeAlias.PluginData.md) | Tuple of plugin class, config, and name. Created by `toPlugin()` and passed to `createApp()`. | | [ResourcePermission](TypeAlias.ResourcePermission.md) | Union of all possible permission levels across all resource types. | +| [ServingFactory](TypeAlias.ServingFactory.md) | Factory function returned by `AppKit.serving`. | | [ToPlugin](TypeAlias.ToPlugin.md) | Factory function type returned by `toPlugin()`. Accepts optional config and returns a PluginData tuple. | ## Variables diff --git a/docs/docs/api/appkit/typedoc-sidebar.ts b/docs/docs/api/appkit/typedoc-sidebar.ts index 2f17b1d2..91815e3d 100644 --- a/docs/docs/api/appkit/typedoc-sidebar.ts +++ b/docs/docs/api/appkit/typedoc-sidebar.ts @@ -97,6 +97,11 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Interface.DatabaseCredential", label: "DatabaseCredential" }, + { + type: "doc", + id: "api/appkit/Interface.EndpointConfig", + label: "EndpointConfig" + }, { type: "doc", id: "api/appkit/Interface.GenerateDatabaseCredentialRequest", @@ -142,6 +147,16 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Interface.ResourceRequirement", label: "ResourceRequirement" }, + { + type: "doc", + id: "api/appkit/Interface.ServingEndpointEntry", + label: "ServingEndpointEntry" + }, + { + type: "doc", + id: "api/appkit/Interface.ServingEndpointRegistry", + label: "ServingEndpointRegistry" + }, { type: "doc", id: "api/appkit/Interface.StreamExecutionSettings", @@ -183,6 +198,11 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/TypeAlias.ResourcePermission", label: "ResourcePermission" }, + { + type: "doc", + id: "api/appkit/TypeAlias.ServingFactory", + label: "ServingFactory" + }, { type: "doc", id: "api/appkit/TypeAlias.ToPlugin", diff --git a/docs/static/appkit-ui/styles.gen.css b/docs/static/appkit-ui/styles.gen.css index 9a9a38eb..a2192039 100644 --- a/docs/static/appkit-ui/styles.gen.css +++ b/docs/static/appkit-ui/styles.gen.css @@ -831,9 +831,6 @@ .max-w-\[calc\(100\%-2rem\)\] { max-width: calc(100% - 2rem); } - .max-w-full { - max-width: 100%; - } .max-w-max { max-width: max-content; } @@ -4514,6 +4511,11 @@ width: calc(var(--spacing) * 5); } } + .\[\&_\[data-slot\=scroll-area-viewport\]\>div\]\:\!block { + & [data-slot=scroll-area-viewport]>div { + display: block !important; + } + } .\[\&_a\]\:underline { & a { text-decoration-line: underline; @@ -4637,11 +4639,26 @@ color: var(--muted-foreground); } } + .\[\&_table\]\:block { + & table { + display: block; + } + } + .\[\&_table\]\:max-w-full { + & table { + max-width: 100%; + } + } .\[\&_table\]\:border-collapse { & table { border-collapse: collapse; } } + .\[\&_table\]\:overflow-x-auto { + & table { + overflow-x: auto; + } + } .\[\&_table\]\:text-xs { & table { font-size: var(--text-xs); @@ -4851,6 +4868,11 @@ width: 100%; } } + .\[\&\>\*\]\:min-w-0 { + &>* { + min-width: calc(var(--spacing) * 0); + } + } .\[\&\>\*\]\:focus-visible\:relative { &>* { &:focus-visible { diff --git a/packages/appkit/src/connectors/serving/client.ts b/packages/appkit/src/connectors/serving/client.ts new file mode 100644 index 00000000..6254426d --- /dev/null +++ b/packages/appkit/src/connectors/serving/client.ts @@ -0,0 +1,223 @@ +import { ApiError, type WorkspaceClient } from "@databricks/sdk-experimental"; +import { createLogger } from "../../logging/logger"; +import type { ServingInvokeOptions } from "./types"; + +const logger = createLogger("connectors:serving"); + +/** + * Builds the invocation URL for a serving endpoint. + * Uses `/served-models/{model}/invocations` when servedModel is specified, + * otherwise `/serving-endpoints/{name}/invocations`. + */ +function buildInvocationUrl( + host: string, + endpointName: string, + servedModel?: string, +): string { + const base = host.startsWith("http") ? host : `https://${host}`; + const encodedName = encodeURIComponent(endpointName); + const path = servedModel + ? `/serving-endpoints/${encodedName}/served-models/${encodeURIComponent(servedModel)}/invocations` + : `/serving-endpoints/${encodedName}/invocations`; + return new URL(path, base).toString(); +} + +/** + * Maps upstream Databricks error status codes to appropriate proxy responses. + */ +function mapUpstreamError( + status: number, + body: string, + headers: Headers, +): ApiError { + const safeMessage = body.length > 500 ? `${body.slice(0, 500)}...` : body; + + let parsed: { message?: string; error?: string } = {}; + try { + parsed = JSON.parse(body); + } catch { + // body is not JSON + } + + const message = parsed.message || parsed.error || safeMessage; + + switch (true) { + case status === 400: + return new ApiError(message, "BAD_REQUEST", 400, undefined, []); + case status === 401 || status === 403: + logger.warn("Authentication failure from serving endpoint: %s", message); + return new ApiError(message, "AUTH_FAILURE", status, undefined, []); + case status === 404: + return new ApiError(message, "NOT_FOUND", 404, undefined, []); + case status === 429: { + const retryAfter = headers.get("retry-after"); + const retryMessage = retryAfter + ? `${message} (retry-after: ${retryAfter})` + : message; + return new ApiError(retryMessage, "RATE_LIMITED", 429, undefined, []); + } + case status === 503: + return new ApiError( + "Endpoint loading, retry shortly", + "SERVICE_UNAVAILABLE", + 503, + undefined, + [], + ); + case status >= 500: + return new ApiError(message, "BAD_GATEWAY", 502, undefined, []); + default: + return new ApiError(message, "UNKNOWN", status, undefined, []); + } +} + +/** + * Invokes a serving endpoint and returns the parsed JSON response. + */ +export async function invoke( + client: WorkspaceClient, + endpointName: string, + body: Record, + options?: ServingInvokeOptions, +): Promise { + const host = client.config.host; + if (!host) { + throw new Error( + "Databricks host is not configured. Set DATABRICKS_HOST or configure client.config.host.", + ); + } + + const url = buildInvocationUrl(host, endpointName, options?.servedModel); + + // Always strip `stream` from the body — the connector controls this + const { stream: _stream, ...cleanBody } = body; + + const headers = new Headers({ + "Content-Type": "application/json", + Accept: "application/json", + }); + await client.config.authenticate(headers); + + logger.debug("Invoking endpoint %s at %s", endpointName, url); + + const res = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify(cleanBody), + signal: options?.signal, + }); + + if (!res.ok) { + const text = await res.text(); + throw mapUpstreamError(res.status, text, res.headers); + } + + return res.json(); +} + +/** + * Invokes a serving endpoint with streaming enabled. + * Yields parsed JSON chunks from the NDJSON SSE response. + */ +export async function* stream( + client: WorkspaceClient, + endpointName: string, + body: Record, + options?: ServingInvokeOptions, +): AsyncGenerator { + const host = client.config.host; + if (!host) { + throw new Error( + "Databricks host is not configured. Set DATABRICKS_HOST or configure client.config.host.", + ); + } + + const url = buildInvocationUrl(host, endpointName, options?.servedModel); + + // Strip any user-provided `stream` and inject `stream: true` + const { stream: _stream, ...cleanBody } = body; + const streamBody = { ...cleanBody, stream: true }; + + const headers = new Headers({ + "Content-Type": "application/json", + Accept: "text/event-stream", + }); + await client.config.authenticate(headers); + + logger.debug("Streaming from endpoint %s at %s", endpointName, url); + + const res = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify(streamBody), + signal: options?.signal, + }); + + if (!res.ok) { + const text = await res.text(); + throw mapUpstreamError(res.status, text, res.headers); + } + + if (!res.body) { + throw new Error("Response body is null — streaming not supported"); + } + + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + const MAX_BUFFER_SIZE = 1024 * 1024; // 1 MB + + try { + while (true) { + if (options?.signal?.aborted) break; + + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + + if (buffer.length > MAX_BUFFER_SIZE) { + logger.warn( + "Stream buffer exceeded %d bytes, discarding incomplete data", + MAX_BUFFER_SIZE, + ); + buffer = ""; + } + + // Process complete lines from the buffer + const lines = buffer.split("\n"); + // Keep the last (potentially incomplete) line in the buffer + buffer = lines.pop() ?? ""; + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed || trimmed.startsWith(":")) continue; // skip empty lines and SSE comments + if (trimmed === "data: [DONE]") return; + + if (trimmed.startsWith("data: ")) { + const jsonStr = trimmed.slice(6); + try { + yield JSON.parse(jsonStr); + } catch { + logger.warn("Failed to parse streaming chunk: %s", jsonStr); + } + } + } + } + + // Process any remaining data in the buffer + if (buffer.trim() && !options?.signal?.aborted) { + const trimmed = buffer.trim(); + if (trimmed.startsWith("data: ") && trimmed !== "data: [DONE]") { + try { + yield JSON.parse(trimmed.slice(6)); + } catch { + logger.warn("Failed to parse final streaming chunk: %s", trimmed); + } + } + } + } finally { + reader.cancel().catch(() => {}); + reader.releaseLock(); + } +} diff --git a/packages/appkit/src/connectors/serving/tests/client.test.ts b/packages/appkit/src/connectors/serving/tests/client.test.ts new file mode 100644 index 00000000..6af859ae --- /dev/null +++ b/packages/appkit/src/connectors/serving/tests/client.test.ts @@ -0,0 +1,303 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { invoke, stream } from "../client"; + +const mockAuthenticate = vi.fn(); + +function createMockClient(host = "https://test.databricks.com") { + return { + config: { + host, + authenticate: mockAuthenticate, + }, + } as any; +} + +describe("Serving Connector", () => { + beforeEach(() => { + mockAuthenticate.mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("invoke", () => { + test("constructs correct URL for endpoint invocation", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(fetchSpy).toHaveBeenCalledWith( + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + expect.objectContaining({ method: "POST" }), + ); + }); + + test("constructs correct URL with servedModel override", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke( + client, + "my-endpoint", + { messages: [] }, + { servedModel: "llama-v2" }, + ); + + expect(fetchSpy).toHaveBeenCalledWith( + "https://test.databricks.com/serving-endpoints/my-endpoint/served-models/llama-v2/invocations", + expect.objectContaining({ method: "POST" }), + ); + }); + + test("authenticates request headers", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(mockAuthenticate).toHaveBeenCalledWith(expect.any(Headers)); + }); + + test("strips stream property from body", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { + messages: [], + stream: true, + temperature: 0.7, + }); + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body).toEqual({ messages: [], temperature: 0.7 }); + expect(body.stream).toBeUndefined(); + }); + + test("returns parsed JSON response", async () => { + const responseData = { choices: [{ message: { content: "Hello" } }] }; + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(responseData), { status: 200 }), + ); + + const client = createMockClient(); + const result = await invoke(client, "my-endpoint", { messages: [] }); + + expect(result).toEqual(responseData); + }); + + test("throws ApiError on 400 response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Invalid params" }), { + status: 400, + }), + ); + + const client = createMockClient(); + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Invalid params"); + }); + + test("throws ApiError on 404 response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Endpoint not found" }), { + status: 404, + }), + ); + + const client = createMockClient(); + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Endpoint not found"); + }); + + test("maps 5xx to 502 bad gateway", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Internal error" }), { + status: 500, + }), + ); + + const client = createMockClient(); + try { + await invoke(client, "my-endpoint", { messages: [] }); + expect.unreachable("Should have thrown"); + } catch (err: any) { + expect(err.statusCode).toBe(502); + } + }); + + test("forwards AbortSignal", async () => { + const controller = new AbortController(); + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke( + client, + "my-endpoint", + { messages: [] }, + { signal: controller.signal }, + ); + + expect(fetchSpy.mock.calls[0][1]?.signal).toBe(controller.signal); + }); + + test("throws when host is not configured", async () => { + const client = { + config: { + host: "", + authenticate: mockAuthenticate, + }, + } as any; + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Databricks host is not configured"); + }); + + test("prepends https:// to host without protocol", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient("test.databricks.com"); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(fetchSpy.mock.calls[0][0]).toContain( + "https://test.databricks.com", + ); + }); + }); + + describe("stream", () => { + function createSSEResponse(chunks: string[]) { + const body = `${chunks.join("\n")}\n`; + return new Response(body, { + status: 200, + headers: { "Content-Type": "text/event-stream" }, + }); + } + + test("yields parsed NDJSON chunks", async () => { + const chunks = [ + 'data: {"choices":[{"delta":{"content":"Hello"}}]}', + 'data: {"choices":[{"delta":{"content":" world"}}]}', + "data: [DONE]", + ]; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + createSSEResponse(chunks), + ); + + const client = createMockClient(); + const results: unknown[] = []; + for await (const chunk of stream(client, "my-endpoint", { + messages: [], + })) { + results.push(chunk); + } + + expect(results).toEqual([ + { choices: [{ delta: { content: "Hello" } }] }, + { choices: [{ delta: { content: " world" } }] }, + ]); + }); + + test("injects stream: true into body", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue(createSSEResponse(["data: [DONE]"])); + + const client = createMockClient(); + // Consume the generator + for await (const _ of stream(client, "my-endpoint", { messages: [] })) { + // noop + } + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body.stream).toBe(true); + }); + + test("strips user-provided stream and re-injects", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue(createSSEResponse(["data: [DONE]"])); + + const client = createMockClient(); + for await (const _ of stream(client, "my-endpoint", { + messages: [], + stream: false, + })) { + // noop + } + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body.stream).toBe(true); + }); + + test("skips SSE comments and empty lines", async () => { + const chunks = [ + ": this is a comment", + "", + 'data: {"choices":[{"delta":{"content":"Hi"}}]}', + "", + "data: [DONE]", + ]; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + createSSEResponse(chunks), + ); + + const client = createMockClient(); + const results: unknown[] = []; + for await (const chunk of stream(client, "my-endpoint", { + messages: [], + })) { + results.push(chunk); + } + + expect(results).toHaveLength(1); + expect(results[0]).toEqual({ choices: [{ delta: { content: "Hi" } }] }); + }); + + test("throws on non-OK response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Rate limited" }), { + status: 429, + headers: { "Retry-After": "5" }, + }), + ); + + const client = createMockClient(); + try { + for await (const _ of stream(client, "my-endpoint", { messages: [] })) { + // noop + } + expect.unreachable("Should have thrown"); + } catch (err: any) { + expect(err.statusCode).toBe(429); + } + }); + }); +}); diff --git a/packages/appkit/src/connectors/serving/types.ts b/packages/appkit/src/connectors/serving/types.ts new file mode 100644 index 00000000..6dd1acba --- /dev/null +++ b/packages/appkit/src/connectors/serving/types.ts @@ -0,0 +1,4 @@ +export interface ServingInvokeOptions { + servedModel?: string; + signal?: AbortSignal; +} diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index 8db7f1d7..662a9178 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -48,7 +48,13 @@ export { } from "./errors"; // Plugin authoring export { Plugin, type ToPlugin, toPlugin } from "./plugin"; -export { analytics, files, genie, lakebase, server } from "./plugins"; +export { analytics, files, genie, lakebase, server, serving } from "./plugins"; +export type { + EndpointConfig, + ServingEndpointEntry, + ServingEndpointRegistry, + ServingFactory, +} from "./plugins/serving/types"; // Registry types and utilities for plugin manifests export type { ConfigSchema, diff --git a/packages/appkit/src/plugins/index.ts b/packages/appkit/src/plugins/index.ts index 7caa040f..4d58082f 100644 --- a/packages/appkit/src/plugins/index.ts +++ b/packages/appkit/src/plugins/index.ts @@ -3,3 +3,4 @@ export * from "./files"; export * from "./genie"; export * from "./lakebase"; export * from "./server"; +export * from "./serving"; diff --git a/packages/appkit/src/plugins/serving/defaults.ts b/packages/appkit/src/plugins/serving/defaults.ts new file mode 100644 index 00000000..1fea64c2 --- /dev/null +++ b/packages/appkit/src/plugins/serving/defaults.ts @@ -0,0 +1,26 @@ +import type { StreamExecutionSettings } from "shared"; + +export const servingInvokeDefaults = { + cache: { + enabled: false, + }, + retry: { + enabled: false, + }, + timeout: 120_000, +}; + +export const servingStreamDefaults: StreamExecutionSettings = { + default: { + cache: { + enabled: false, + }, + retry: { + enabled: false, + }, + timeout: 120_000, + }, + stream: { + bufferSize: 200, + }, +}; diff --git a/packages/appkit/src/plugins/serving/index.ts b/packages/appkit/src/plugins/serving/index.ts new file mode 100644 index 00000000..85caf33b --- /dev/null +++ b/packages/appkit/src/plugins/serving/index.ts @@ -0,0 +1,2 @@ +export * from "./serving"; +export * from "./types"; diff --git a/packages/appkit/src/plugins/serving/manifest.json b/packages/appkit/src/plugins/serving/manifest.json new file mode 100644 index 00000000..9ac0845f --- /dev/null +++ b/packages/appkit/src/plugins/serving/manifest.json @@ -0,0 +1,54 @@ +{ + "$schema": "https://databricks.github.io/appkit/schemas/plugin-manifest.schema.json", + "name": "serving", + "displayName": "Model Serving Plugin", + "description": "Authenticated proxy to Databricks Model Serving endpoints", + "resources": { + "required": [ + { + "type": "serving_endpoint", + "alias": "Serving Endpoint", + "resourceKey": "serving-endpoint", + "description": "Model Serving endpoint for inference", + "permission": "CAN_QUERY", + "fields": { + "name": { + "env": "DATABRICKS_SERVING_ENDPOINT", + "description": "Serving endpoint name" + } + } + } + ], + "optional": [] + }, + "config": { + "schema": { + "type": "object", + "properties": { + "endpoints": { + "type": "object", + "description": "Map of alias names to endpoint configurations", + "additionalProperties": { + "type": "object", + "properties": { + "env": { + "type": "string", + "description": "Environment variable holding the endpoint name" + }, + "servedModel": { + "type": "string", + "description": "Target a specific served model (bypasses traffic routing)" + } + }, + "required": ["env"] + } + }, + "timeout": { + "type": "number", + "default": 120000, + "description": "Request timeout in ms. Default: 120000 (2 min)" + } + } + } + } +} diff --git a/packages/appkit/src/plugins/serving/schema-filter.ts b/packages/appkit/src/plugins/serving/schema-filter.ts new file mode 100644 index 00000000..6e52294a --- /dev/null +++ b/packages/appkit/src/plugins/serving/schema-filter.ts @@ -0,0 +1,127 @@ +import fs from "node:fs/promises"; +import { createLogger } from "../../logging/logger"; + +const CACHE_VERSION = "1"; + +interface ServingCacheEntry { + hash: string; + requestType: string; + responseType: string; + chunkType: string | null; +} + +interface ServingCache { + version: string; + endpoints: Record; +} + +const logger = createLogger("serving:schema-filter"); + +function isValidCache(data: unknown): data is ServingCache { + return ( + typeof data === "object" && + data !== null && + "version" in data && + (data as ServingCache).version === CACHE_VERSION && + "endpoints" in data && + typeof (data as ServingCache).endpoints === "object" + ); +} + +/** + * Loads endpoint schemas from the type generation cache file. + * Returns a map of alias → allowed parameter keys. + */ +export async function loadEndpointSchemas( + cacheFile: string, +): Promise>> { + const allowlists = new Map>(); + + try { + const raw = await fs.readFile(cacheFile, "utf8"); + const parsed: unknown = JSON.parse(raw); + if (!isValidCache(parsed)) { + logger.warn("Serving types cache has invalid structure, skipping"); + return allowlists; + } + const cache = parsed; + + for (const [alias, entry] of Object.entries(cache.endpoints)) { + // Extract property keys from the requestType string + // The requestType is a TypeScript object type like "{ messages: ...; temperature: ...; }" + const keys = extractPropertyKeys(entry.requestType); + if (keys.size > 0) { + allowlists.set(alias, keys); + } + } + } catch (err) { + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + logger.warn( + "Failed to load serving types cache: %s", + (err as Error).message, + ); + } + // No cache → no filtering, passthrough mode + } + + return allowlists; +} + +/** + * Extracts top-level property keys from a TypeScript object type string. + * Matches patterns like `key:` or `key?:` at the first nesting level. + */ +function extractPropertyKeys(typeStr: string): Set { + const keys = new Set(); + // Match property names at the top level of the object type + // Looking for patterns: ` propertyName:` or ` propertyName?:` + const propRegex = /^\s{2}(?:\/\*\*[^*]*\*\/\s*)?(\w+)\??:/gm; + for ( + let match = propRegex.exec(typeStr); + match !== null; + match = propRegex.exec(typeStr) + ) { + keys.add(match[1]); + } + return keys; +} + +/** + * Filters a request body against the allowed keys for an endpoint alias. + * Returns the filtered body and logs a warning for stripped params. + * + * If no allowlist exists for the alias, returns the body unchanged (passthrough). + */ +export function filterRequestBody( + body: Record, + allowlists: Map>, + alias: string, + filterMode: "strip" | "reject" = "strip", +): Record { + const allowed = allowlists.get(alias); + if (!allowed) return body; + + const stripped: string[] = []; + const filtered: Record = {}; + + for (const [key, value] of Object.entries(body)) { + if (allowed.has(key)) { + filtered[key] = value; + } else { + stripped.push(key); + } + } + + if (stripped.length > 0) { + if (filterMode === "reject") { + throw new Error(`Unknown request parameters: ${stripped.join(", ")}`); + } + logger.warn( + "Stripped unknown params from '%s': %s", + alias, + stripped.join(", "), + ); + } + + return filtered; +} diff --git a/packages/appkit/src/plugins/serving/serving.ts b/packages/appkit/src/plugins/serving/serving.ts new file mode 100644 index 00000000..e868cc02 --- /dev/null +++ b/packages/appkit/src/plugins/serving/serving.ts @@ -0,0 +1,303 @@ +import { randomUUID } from "node:crypto"; +import path from "node:path"; +import type express from "express"; +import type { IAppRouter, StreamExecutionSettings } from "shared"; +import * as servingConnector from "../../connectors/serving/client"; +import { getWorkspaceClient } from "../../context"; +import { createLogger } from "../../logging"; +import { Plugin, toPlugin } from "../../plugin"; +import type { PluginManifest, ResourceRequirement } from "../../registry"; +import { ResourceType } from "../../registry"; +import { servingInvokeDefaults, servingStreamDefaults } from "./defaults"; +import manifest from "./manifest.json"; +import { filterRequestBody, loadEndpointSchemas } from "./schema-filter"; +import type { EndpointConfig, IServingConfig, ServingFactory } from "./types"; + +const logger = createLogger("serving"); + +class EndpointNotFoundError extends Error { + constructor(alias: string) { + super(`Unknown endpoint alias: ${alias}`); + } +} + +class EndpointNotConfiguredError extends Error { + constructor(alias: string, envVar: string) { + super( + `Endpoint '${alias}' is not configured: env var '${envVar}' is not set`, + ); + } +} + +interface ResolvedEndpoint { + name: string; + servedModel?: string; +} + +export class ServingPlugin extends Plugin { + static manifest = manifest as PluginManifest<"serving">; + + protected static description = + "Authenticated proxy to Databricks Model Serving endpoints"; + protected declare config: IServingConfig; + + private readonly endpoints: Record; + private readonly isNamedMode: boolean; + private schemaAllowlists = new Map>(); + + constructor(config: IServingConfig) { + super(config); + this.config = config; + + if (config.endpoints) { + this.endpoints = config.endpoints; + this.isNamedMode = true; + } else { + this.endpoints = { + default: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }; + this.isNamedMode = false; + } + } + + async setup(): Promise { + const cacheFile = path.join( + process.cwd(), + "node_modules", + ".databricks", + "appkit", + ".appkit-serving-types-cache.json", + ); + this.schemaAllowlists = await loadEndpointSchemas(cacheFile); + if (this.schemaAllowlists.size > 0) { + logger.debug( + "Loaded schema allowlists for %d endpoint(s)", + this.schemaAllowlists.size, + ); + } + } + + static getResourceRequirements( + config: IServingConfig, + ): ResourceRequirement[] { + const endpoints = config.endpoints ?? { + default: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }; + + return Object.entries(endpoints).map(([alias, endpointConfig]) => ({ + type: ResourceType.SERVING_ENDPOINT, + alias: `serving-${alias}`, + resourceKey: `serving-${alias}`, + description: `Model Serving endpoint for "${alias}" inference`, + permission: "CAN_QUERY" as const, + fields: { + name: { + env: endpointConfig.env, + description: `Serving endpoint name for "${alias}"`, + }, + }, + required: true, + })); + } + + private resolveAndFilter( + alias: string, + body: Record, + ): { endpoint: ResolvedEndpoint; filteredBody: Record } { + const config = this.endpoints[alias]; + if (!config) { + throw new EndpointNotFoundError(alias); + } + + const name = process.env[config.env]; + if (!name) { + throw new EndpointNotConfiguredError(alias, config.env); + } + + const endpoint: ResolvedEndpoint = { + name, + servedModel: config.servedModel, + }; + const filteredBody = filterRequestBody( + body, + this.schemaAllowlists, + alias, + this.config.filterMode, + ); + return { endpoint, filteredBody }; + } + + injectRoutes(router: IAppRouter) { + if (this.isNamedMode) { + this.route(router, { + name: "invoke", + method: "post", + path: "/:alias/invoke", + handler: async (req: express.Request, res: express.Response) => { + await this.asUser(req)._handleInvoke(req, res); + }, + }); + + this.route(router, { + name: "stream", + method: "post", + path: "/:alias/stream", + handler: async (req: express.Request, res: express.Response) => { + await this.asUser(req)._handleStream(req, res); + }, + }); + } else { + this.route(router, { + name: "invoke", + method: "post", + path: "/invoke", + handler: async (req: express.Request, res: express.Response) => { + req.params.alias = "default"; + await this.asUser(req)._handleInvoke(req, res); + }, + }); + + this.route(router, { + name: "stream", + method: "post", + path: "/stream", + handler: async (req: express.Request, res: express.Response) => { + req.params.alias = "default"; + await this.asUser(req)._handleStream(req, res); + }, + }); + } + } + + async _handleInvoke( + req: express.Request, + res: express.Response, + ): Promise { + const { alias } = req.params; + const rawBody = req.body as Record; + + try { + const result = await this.invoke(alias, rawBody); + if (result === undefined) { + res.status(502).json({ error: "Invocation returned no result" }); + return; + } + res.json(result); + } catch (err) { + const message = err instanceof Error ? err.message : "Invocation failed"; + if (err instanceof EndpointNotFoundError) { + res.status(404).json({ error: message }); + } else if ( + err instanceof EndpointNotConfiguredError || + message.startsWith("Unknown request parameters:") + ) { + res.status(400).json({ error: message }); + } else { + res.status(502).json({ error: message }); + } + } + } + + async _handleStream( + req: express.Request, + res: express.Response, + ): Promise { + const { alias } = req.params; + const rawBody = req.body as Record; + + let endpoint: ResolvedEndpoint; + let filteredBody: Record; + try { + ({ endpoint, filteredBody } = this.resolveAndFilter(alias, rawBody)); + } catch (err) { + const message = err instanceof Error ? err.message : "Invalid request"; + const status = err instanceof EndpointNotFoundError ? 404 : 400; + res.status(status).json({ error: message }); + return; + } + + const timeout = this.config.timeout ?? 120_000; + const requestId = + (typeof req.query.requestId === "string" && req.query.requestId) || + randomUUID(); + + const streamSettings: StreamExecutionSettings = { + ...servingStreamDefaults, + default: { + ...servingStreamDefaults.default, + timeout, + }, + stream: { + ...servingStreamDefaults.stream, + streamId: requestId, + }, + }; + + const workspaceClient = getWorkspaceClient(); + if (!workspaceClient.config.host) { + res.status(500).json({ error: "Databricks host not configured" }); + return; + } + + await this.executeStream( + res, + () => + servingConnector.stream(workspaceClient, endpoint.name, filteredBody, { + servedModel: endpoint.servedModel, + }), + streamSettings, + ); + } + + async invoke(alias: string, body: Record): Promise { + const { endpoint, filteredBody } = this.resolveAndFilter(alias, body); + const workspaceClient = getWorkspaceClient(); + const timeout = this.config.timeout ?? 120_000; + + return this.execute( + () => + servingConnector.invoke(workspaceClient, endpoint.name, filteredBody, { + servedModel: endpoint.servedModel, + }), + { + default: { + ...servingInvokeDefaults, + timeout, + }, + }, + ); + } + + async *stream( + alias: string, + body: Record, + ): AsyncGenerator { + const { endpoint, filteredBody } = this.resolveAndFilter(alias, body); + const workspaceClient = getWorkspaceClient(); + + yield* servingConnector.stream( + workspaceClient, + endpoint.name, + filteredBody, + { servedModel: endpoint.servedModel }, + ); + } + + async shutdown(): Promise { + this.streamManager.abortAll(); + } + + exports(): ServingFactory { + return ((alias?: string) => ({ + invoke: (body: Record) => + this.invoke(alias ?? "default", body), + stream: (body: Record) => + this.stream(alias ?? "default", body), + })) as ServingFactory; + } +} + +/** + * @internal + */ +export const serving = toPlugin(ServingPlugin); diff --git a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts new file mode 100644 index 00000000..948b47f9 --- /dev/null +++ b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts @@ -0,0 +1,141 @@ +import { describe, expect, test, vi } from "vitest"; +import { filterRequestBody, loadEndpointSchemas } from "../schema-filter"; + +vi.mock("node:fs/promises", () => ({ + default: { + readFile: vi.fn(), + }, +})); + +describe("schema-filter", () => { + describe("filterRequestBody", () => { + test("strips unknown keys when allowlist exists", () => { + const allowlists = new Map([ + ["default", new Set(["messages", "temperature"])], + ]); + + const result = filterRequestBody( + { messages: [], temperature: 0.7, unknown_param: true }, + allowlists, + "default", + ); + + expect(result).toEqual({ messages: [], temperature: 0.7 }); + }); + + test("preserves all keys when no allowlist for alias", () => { + const allowlists = new Map>(); + + const body = { messages: [], custom: "value" }; + const result = filterRequestBody(body, allowlists, "default"); + + expect(result).toBe(body); // Same reference, no filtering + }); + + test("returns empty object when all keys are unknown", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + const result = filterRequestBody( + { bad1: 1, bad2: 2 }, + allowlists, + "default", + ); + + expect(result).toEqual({}); + }); + + test("returns full body when all keys are allowed", () => { + const allowlists = new Map([["default", new Set(["a", "b", "c"])]]); + + const result = filterRequestBody( + { a: 1, b: 2, c: 3 }, + allowlists, + "default", + ); + + expect(result).toEqual({ a: 1, b: 2, c: 3 }); + }); + + test("throws in reject mode when unknown keys are present", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + expect(() => + filterRequestBody( + { messages: [], unknown_param: true }, + allowlists, + "default", + "reject", + ), + ).toThrow("Unknown request parameters: unknown_param"); + }); + + test("does not throw in reject mode when all keys are allowed", () => { + const allowlists = new Map([ + ["default", new Set(["messages", "temperature"])], + ]); + + const result = filterRequestBody( + { messages: [], temperature: 0.7 }, + allowlists, + "default", + "reject", + ); + + expect(result).toEqual({ messages: [], temperature: 0.7 }); + }); + + test("strips in default mode (strip)", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + const result = filterRequestBody( + { messages: [], extra: true }, + allowlists, + "default", + "strip", + ); + + expect(result).toEqual({ messages: [] }); + }); + }); + + describe("loadEndpointSchemas", () => { + test("returns empty map when cache file does not exist", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockRejectedValue( + Object.assign(new Error("ENOENT"), { code: "ENOENT" }), + ); + + const result = await loadEndpointSchemas("/nonexistent/path"); + expect(result.size).toBe(0); + }); + + test("extracts property keys from cached types", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + version: "1", + endpoints: { + default: { + hash: "abc", + requestType: `{ + messages: string[]; + temperature?: number | null; + max_tokens: number; +}`, + responseType: "{}", + chunkType: null, + }, + }, + }), + ); + + const result = await loadEndpointSchemas("/some/path"); + expect(result.size).toBe(1); + const keys = result.get("default"); + expect(keys).toBeDefined(); + expect(keys?.has("messages")).toBe(true); + expect(keys?.has("temperature")).toBe(true); + expect(keys?.has("max_tokens")).toBe(true); + }); + }); +}); diff --git a/packages/appkit/src/plugins/serving/tests/serving.test.ts b/packages/appkit/src/plugins/serving/tests/serving.test.ts new file mode 100644 index 00000000..1a953b77 --- /dev/null +++ b/packages/appkit/src/plugins/serving/tests/serving.test.ts @@ -0,0 +1,339 @@ +import { + createMockRequest, + createMockResponse, + createMockRouter, + mockServiceContext, + setupDatabricksEnv, +} from "@tools/test-helpers"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { ServiceContext } from "../../../context/service-context"; +import { ServingPlugin, serving } from "../serving"; +import type { IServingConfig } from "../types"; + +// Mock CacheManager singleton +const { mockCacheInstance } = vi.hoisted(() => { + const instance = { + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi + .fn() + .mockImplementation( + async (_key: unknown[], fn: () => Promise) => { + return await fn(); + }, + ), + generateKey: vi.fn((...args: unknown[]) => JSON.stringify(args)), + }; + return { mockCacheInstance: instance }; +}); + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => mockCacheInstance), + }, +})); + +// Mock the serving connector +const mockInvoke = vi.fn(); +const mockStream = vi.fn(); + +vi.mock("../../../connectors/serving/client", () => ({ + invoke: (...args: any[]) => mockInvoke(...args), + stream: (...args: any[]) => mockStream(...args), +})); + +describe("Serving Plugin", () => { + let serviceContextMock: Awaited>; + + beforeEach(async () => { + setupDatabricksEnv(); + process.env.DATABRICKS_SERVING_ENDPOINT = "test-endpoint"; + ServiceContext.reset(); + + serviceContextMock = await mockServiceContext(); + }); + + afterEach(() => { + serviceContextMock?.restore(); + delete process.env.DATABRICKS_SERVING_ENDPOINT; + vi.restoreAllMocks(); + }); + + test("serving factory should have correct name", () => { + const pluginData = serving(); + expect(pluginData.name).toBe("serving"); + }); + + test("serving factory with config should have correct name", () => { + const pluginData = serving({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + expect(pluginData.name).toBe("serving"); + }); + + describe("default mode", () => { + test("reads DATABRICKS_SERVING_ENDPOINT", () => { + const plugin = new ServingPlugin({}); + const api = (plugin.exports() as any)(); + expect(api.invoke).toBeDefined(); + expect(api.stream).toBeDefined(); + }); + + test("injects /invoke and /stream routes", () => { + const plugin = new ServingPlugin({}); + const { router, handlers } = createMockRouter(); + + plugin.injectRoutes(router); + + expect(handlers["POST:/invoke"]).toBeDefined(); + expect(handlers["POST:/stream"]).toBeDefined(); + }); + + test("exports returns a factory that provides invoke and stream", () => { + const plugin = new ServingPlugin({}); + const factory = plugin.exports() as any; + const api = factory(); + + expect(typeof api.invoke).toBe("function"); + expect(typeof api.stream).toBe("function"); + }); + }); + + describe("named mode", () => { + const namedConfig: IServingConfig = { + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT" }, + embedder: { env: "DATABRICKS_SERVING_ENDPOINT_EMBEDDING" }, + }, + }; + + test("injects /:alias/invoke and /:alias/stream routes", () => { + const plugin = new ServingPlugin(namedConfig); + const { router, handlers } = createMockRouter(); + + plugin.injectRoutes(router); + + expect(handlers["POST:/:alias/invoke"]).toBeDefined(); + expect(handlers["POST:/:alias/stream"]).toBeDefined(); + }); + + test("exports factory returns invoke and stream for named aliases", () => { + const plugin = new ServingPlugin(namedConfig); + const factory = plugin.exports() as any; + + expect(typeof factory("llm").invoke).toBe("function"); + expect(typeof factory("llm").stream).toBe("function"); + expect(typeof factory("embedder").invoke).toBe("function"); + expect(typeof factory("embedder").stream).toBe("function"); + }); + }); + + describe("route handlers", () => { + test("_handleInvoke returns 404 for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + const req = createMockRequest({ + params: { alias: "unknown" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: "Unknown endpoint alias: unknown", + }); + }); + + test("_handleInvoke calls connector with correct endpoint", async () => { + mockInvoke.mockResolvedValue({ choices: [] }); + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [{ role: "user", content: "Hello" }] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(mockInvoke).toHaveBeenCalledWith( + expect.anything(), + "test-endpoint", + { messages: [{ role: "user", content: "Hello" }] }, + { servedModel: undefined }, + ); + expect(res.json).toHaveBeenCalledWith({ choices: [] }); + }); + + test("_handleInvoke returns 400 with descriptive message when env var is not set", async () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT; + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith({ + error: + "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT' is not set", + }); + }); + + test("_handleInvoke does not throw when connector fails", async () => { + mockInvoke.mockRejectedValue(new Error("Connection refused")); + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + // Should not throw — execute() handles the error internally + await expect( + plugin._handleInvoke(req as any, res as any), + ).resolves.not.toThrow(); + }); + + test("_handleStream returns 404 for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + const req = createMockRequest({ + params: { alias: "unknown" }, + body: { messages: [] }, + query: {}, + }); + const res = createMockResponse(); + + await plugin._handleStream(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: "Unknown endpoint alias: unknown", + }); + }); + + test("_handleStream returns 400 when env var is not set", async () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT; + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + query: {}, + }); + const res = createMockResponse(); + + await plugin._handleStream(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith({ + error: + "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT' is not set", + }); + }); + }); + + describe("getResourceRequirements", () => { + test("generates requirements for default mode", () => { + const reqs = ServingPlugin.getResourceRequirements({}); + expect(reqs).toHaveLength(1); + expect(reqs[0]).toMatchObject({ + type: "serving_endpoint", + alias: "serving-default", + permission: "CAN_QUERY", + fields: { + name: { + env: "DATABRICKS_SERVING_ENDPOINT", + }, + }, + }); + }); + + test("generates requirements for named mode", () => { + const reqs = ServingPlugin.getResourceRequirements({ + endpoints: { + llm: { env: "LLM_ENDPOINT" }, + embedder: { env: "EMBED_ENDPOINT" }, + }, + }); + expect(reqs).toHaveLength(2); + expect(reqs[0].fields.name.env).toBe("LLM_ENDPOINT"); + expect(reqs[1].fields.name.env).toBe("EMBED_ENDPOINT"); + }); + }); + + describe("programmatic API", () => { + test("invoke calls connector correctly", async () => { + mockInvoke.mockResolvedValue({ + choices: [{ message: { content: "Hi" } }], + }); + + const plugin = new ServingPlugin({}); + const result = await plugin.invoke("default", { messages: [] }); + + expect(mockInvoke).toHaveBeenCalledWith( + expect.anything(), + "test-endpoint", + { messages: [] }, + { servedModel: undefined }, + ); + expect(result).toEqual({ choices: [{ message: { content: "Hi" } }] }); + }); + + test("invoke throws for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + await expect(plugin.invoke("unknown", { messages: [] })).rejects.toThrow( + "Unknown endpoint alias: unknown", + ); + }); + + test("stream yields chunks from connector", async () => { + const chunks = [ + { choices: [{ delta: { content: "Hello" } }] }, + { choices: [{ delta: { content: " world" } }] }, + ]; + + mockStream.mockImplementation(async function* () { + for (const chunk of chunks) { + yield chunk; + } + }); + + const plugin = new ServingPlugin({}); + const results: unknown[] = []; + for await (const chunk of plugin.stream("default", { messages: [] })) { + results.push(chunk); + } + + expect(results).toEqual(chunks); + }); + }); + + describe("shutdown", () => { + test("calls streamManager.abortAll", async () => { + const plugin = new ServingPlugin({}); + // Accessing the protected streamManager through the plugin + const abortSpy = vi.spyOn((plugin as any).streamManager, "abortAll"); + + await plugin.shutdown(); + + expect(abortSpy).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/appkit/src/plugins/serving/types.ts b/packages/appkit/src/plugins/serving/types.ts new file mode 100644 index 00000000..9a2dd230 --- /dev/null +++ b/packages/appkit/src/plugins/serving/types.ts @@ -0,0 +1,67 @@ +import type { BasePluginConfig } from "shared"; + +export interface EndpointConfig { + /** Environment variable holding the endpoint name. */ + env: string; + /** Target a specific served model (bypasses traffic routing). */ + servedModel?: string; +} + +export interface IServingConfig extends BasePluginConfig { + /** Map of alias → endpoint config. Defaults to { default: { env: "DATABRICKS_SERVING_ENDPOINT" } } if omitted. */ + endpoints?: Record; + /** Request timeout in ms. Default: 120000 (2 min) */ + timeout?: number; + /** How to handle unknown request parameters. 'strip' silently removes them (default). 'reject' returns 400. */ + filterMode?: "strip" | "reject"; +} + +/** + * Registry interface for serving endpoint type generation. + * Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. + * When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. + */ +// biome-ignore lint/suspicious/noEmptyInterface: intentionally empty — populated via module augmentation +export interface ServingEndpointRegistry {} + +/** Shape of a single registry entry. */ +export interface ServingEndpointEntry { + request: Record; + response: unknown; + chunk: unknown; +} + +/** Typed invoke/stream methods for a serving endpoint. */ +export interface ServingEndpointMethods< + TRequest extends Record = Record, + TResponse = unknown, + TChunk = unknown, +> { + invoke: (body: TRequest) => Promise; + stream: (body: TRequest) => AsyncGenerator; +} + +/** + * Factory function returned by `AppKit.serving`. + * + * This is a conditional type that adapts based on whether `ServingEndpointRegistry` + * has been populated via module augmentation (generated by `appKitServingTypesPlugin()`): + * + * - **Registry empty (default):** `(alias?: string) => ServingEndpointMethods` — + * accepts any alias string with untyped request/response/chunk. + * - **Registry populated:** `(alias: K) => ServingEndpointMethods<...>` — + * restricts `alias` to known endpoint keys and infers typed request/response/chunk + * from the registry entry. + * + * Run `appKitServingTypesPlugin()` in your Vite config to generate the registry + * augmentation and enable full type safety. + */ +export type ServingFactory = keyof ServingEndpointRegistry extends never + ? (alias?: string) => ServingEndpointMethods + : ( + alias: K, + ) => ServingEndpointMethods< + ServingEndpointRegistry[K]["request"], + ServingEndpointRegistry[K]["response"], + ServingEndpointRegistry[K]["chunk"] + >; diff --git a/packages/appkit/src/stream/stream-manager.ts b/packages/appkit/src/stream/stream-manager.ts index 41764772..8b511fac 100644 --- a/packages/appkit/src/stream/stream-manager.ts +++ b/packages/appkit/src/stream/stream-manager.ts @@ -374,6 +374,14 @@ export class StreamManager { if (error.name === "AbortError") { return SSEErrorCode.STREAM_ABORTED; } + + // Detect upstream API errors (e.g., from Databricks SDK ApiError) + if ( + "statusCode" in error && + typeof (error as any).statusCode === "number" + ) { + return SSEErrorCode.UPSTREAM_ERROR; + } } return SSEErrorCode.INTERNAL_ERROR; diff --git a/packages/appkit/src/stream/types.ts b/packages/appkit/src/stream/types.ts index 0fd862ba..3841bfd1 100644 --- a/packages/appkit/src/stream/types.ts +++ b/packages/appkit/src/stream/types.ts @@ -16,6 +16,7 @@ export const SSEErrorCode = { INVALID_REQUEST: "INVALID_REQUEST", STREAM_ABORTED: "STREAM_ABORTED", STREAM_EVICTED: "STREAM_EVICTED", + UPSTREAM_ERROR: "UPSTREAM_ERROR", } as const satisfies Record; export type SSEErrorCode = (typeof SSEErrorCode)[keyof typeof SSEErrorCode]; From 218f9b58c373794fdb024d8d48cedc2c723c943d Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 14:20:34 +0200 Subject: [PATCH 02/13] fix: pass abort signal to serving connector in stream handler The serving plugin was not forwarding the abort signal to the serving connector, unlike the genie plugin. Without the signal, the connector's fetch request cannot be cancelled and the abort-check loop never triggers. Signed-off-by: Pawel Kosiec --- packages/appkit/src/plugins/serving/serving.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/appkit/src/plugins/serving/serving.ts b/packages/appkit/src/plugins/serving/serving.ts index e868cc02..e3547bcf 100644 --- a/packages/appkit/src/plugins/serving/serving.ts +++ b/packages/appkit/src/plugins/serving/serving.ts @@ -241,9 +241,10 @@ export class ServingPlugin extends Plugin { await this.executeStream( res, - () => + (signal) => servingConnector.stream(workspaceClient, endpoint.name, filteredBody, { servedModel: endpoint.servedModel, + signal, }), streamSettings, ); From 7ba88a217304a46bc0d5250d52748192b7511aa4 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 12:05:08 +0200 Subject: [PATCH 03/13] feat: add serving type generator, Vite plugin, and UI hooks Add Vite plugin that auto-generates TypeScript types from serving endpoint OpenAPI schemas. Includes AST-based server file extraction (@ast-grep/napi), schema-to-TypeScript conversion, and caching. Also adds useServingInvoke and useServingStream React hooks in appkit-ui with full type-safe registry support. Signed-off-by: Pawel Kosiec --- .gitignore | 3 + .../Function.appKitServingTypesPlugin.md | 24 ++ .../Function.extractServingEndpoints.md | 24 ++ .../api/appkit/Function.findServerFile.md | 19 ++ docs/docs/api/appkit/index.md | 3 + docs/docs/api/appkit/typedoc-sidebar.ts | 15 + .../__tests__/use-serving-invoke.test.ts | 117 ++++++++ .../__tests__/use-serving-stream.test.ts | 271 +++++++++++++++++ packages/appkit-ui/src/react/hooks/index.ts | 15 + packages/appkit-ui/src/react/hooks/types.ts | 51 ++++ .../src/react/hooks/use-serving-invoke.ts | 103 +++++++ .../src/react/hooks/use-serving-stream.ts | 123 ++++++++ packages/appkit/package.json | 1 + packages/appkit/src/index.ts | 6 +- .../src/plugins/serving/schema-filter.ts | 18 +- .../src/type-generator/serving/cache.ts | 55 ++++ .../src/type-generator/serving/converter.ts | 149 ++++++++++ .../src/type-generator/serving/fetcher.ts | 158 ++++++++++ .../src/type-generator/serving/generator.ts | 266 +++++++++++++++++ .../serving/server-file-extractor.ts | 221 ++++++++++++++ .../serving/tests/cache.test.ts | 107 +++++++ .../serving/tests/converter.test.ts | 278 ++++++++++++++++++ .../serving/tests/fetcher.test.ts | 209 +++++++++++++ .../serving/tests/generator.test.ts | 215 ++++++++++++++ .../tests/server-file-extractor.test.ts | 213 ++++++++++++++ .../serving/tests/vite-plugin.test.ts | 186 ++++++++++++ .../src/type-generator/serving/vite-plugin.ts | 109 +++++++ pnpm-lock.yaml | 3 + 28 files changed, 2947 insertions(+), 15 deletions(-) create mode 100644 docs/docs/api/appkit/Function.appKitServingTypesPlugin.md create mode 100644 docs/docs/api/appkit/Function.extractServingEndpoints.md create mode 100644 docs/docs/api/appkit/Function.findServerFile.md create mode 100644 packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts create mode 100644 packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts create mode 100644 packages/appkit-ui/src/react/hooks/use-serving-invoke.ts create mode 100644 packages/appkit-ui/src/react/hooks/use-serving-stream.ts create mode 100644 packages/appkit/src/type-generator/serving/cache.ts create mode 100644 packages/appkit/src/type-generator/serving/converter.ts create mode 100644 packages/appkit/src/type-generator/serving/fetcher.ts create mode 100644 packages/appkit/src/type-generator/serving/generator.ts create mode 100644 packages/appkit/src/type-generator/serving/server-file-extractor.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/cache.test.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/converter.test.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/fetcher.test.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/generator.test.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts create mode 100644 packages/appkit/src/type-generator/serving/vite-plugin.ts diff --git a/.gitignore b/.gitignore index 3b6cc969..4c51d5b1 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ coverage *.tsbuildinfo .turbo + +# AppKit type generator caches +.databricks diff --git a/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md b/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md new file mode 100644 index 00000000..bc28660a --- /dev/null +++ b/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md @@ -0,0 +1,24 @@ +# Function: appKitServingTypesPlugin() + +```ts +function appKitServingTypesPlugin(options?: AppKitServingTypesPluginOptions): Plugin$1; +``` + +Vite plugin to generate TypeScript types for AppKit serving endpoints. +Fetches OpenAPI schemas from Databricks and generates a .d.ts with +ServingEndpointRegistry module augmentation. + +Endpoint discovery order: +1. Explicit `endpoints` option (override) +2. AST extraction from server file (server/index.ts or server/server.ts) +3. DATABRICKS_SERVING_ENDPOINT env var (single default endpoint) + +## Parameters + +| Parameter | Type | +| ------ | ------ | +| `options?` | `AppKitServingTypesPluginOptions` | + +## Returns + +`Plugin$1` diff --git a/docs/docs/api/appkit/Function.extractServingEndpoints.md b/docs/docs/api/appkit/Function.extractServingEndpoints.md new file mode 100644 index 00000000..24a5b00d --- /dev/null +++ b/docs/docs/api/appkit/Function.extractServingEndpoints.md @@ -0,0 +1,24 @@ +# Function: extractServingEndpoints() + +```ts +function extractServingEndpoints(serverFilePath: string): + | Record + | null; +``` + +Extract serving endpoint config from a server file by AST-parsing it. +Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls +and extracts the endpoint alias names and their environment variable mappings. + +## Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `serverFilePath` | `string` | Absolute path to the server entry file | + +## Returns + + \| `Record`\<`string`, [`EndpointConfig`](Interface.EndpointConfig.md)\> + \| `null` + +Extracted endpoint config, or null if not found or not extractable diff --git a/docs/docs/api/appkit/Function.findServerFile.md b/docs/docs/api/appkit/Function.findServerFile.md new file mode 100644 index 00000000..2ed4e268 --- /dev/null +++ b/docs/docs/api/appkit/Function.findServerFile.md @@ -0,0 +1,19 @@ +# Function: findServerFile() + +```ts +function findServerFile(basePath: string): string | null; +``` + +Find the server entry file by checking candidate paths in order. + +## Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `basePath` | `string` | Project root directory to search from | + +## Returns + +`string` \| `null` + +Absolute path to the server file, or null if none found diff --git a/docs/docs/api/appkit/index.md b/docs/docs/api/appkit/index.md index f4685e04..faadf237 100644 --- a/docs/docs/api/appkit/index.md +++ b/docs/docs/api/appkit/index.md @@ -70,9 +70,12 @@ plugin architecture, and React integration. | Function | Description | | ------ | ------ | +| [appKitServingTypesPlugin](Function.appKitServingTypesPlugin.md) | Vite plugin to generate TypeScript types for AppKit serving endpoints. Fetches OpenAPI schemas from Databricks and generates a .d.ts with ServingEndpointRegistry module augmentation. | | [appKitTypesPlugin](Function.appKitTypesPlugin.md) | Vite plugin to generate types for AppKit queries. Calls generateFromEntryPoint under the hood. | | [createApp](Function.createApp.md) | Bootstraps AppKit with the provided configuration. | | [createLakebasePool](Function.createLakebasePool.md) | Create a Lakebase pool with appkit's logger integration. Telemetry automatically uses appkit's OpenTelemetry configuration via global registry. | +| [extractServingEndpoints](Function.extractServingEndpoints.md) | Extract serving endpoint config from a server file by AST-parsing it. Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls and extracts the endpoint alias names and their environment variable mappings. | +| [findServerFile](Function.findServerFile.md) | Find the server entry file by checking candidate paths in order. | | [generateDatabaseCredential](Function.generateDatabaseCredential.md) | Generate OAuth credentials for Postgres database connection using the proper Postgres API. | | [getExecutionContext](Function.getExecutionContext.md) | Get the current execution context. | | [getLakebaseOrmConfig](Function.getLakebaseOrmConfig.md) | Get Lakebase connection configuration for ORMs that don't accept pg.Pool directly. | diff --git a/docs/docs/api/appkit/typedoc-sidebar.ts b/docs/docs/api/appkit/typedoc-sidebar.ts index 91815e3d..1d498d1a 100644 --- a/docs/docs/api/appkit/typedoc-sidebar.ts +++ b/docs/docs/api/appkit/typedoc-sidebar.ts @@ -225,6 +225,11 @@ const typedocSidebar: SidebarsConfig = { type: "category", label: "Functions", items: [ + { + type: "doc", + id: "api/appkit/Function.appKitServingTypesPlugin", + label: "appKitServingTypesPlugin" + }, { type: "doc", id: "api/appkit/Function.appKitTypesPlugin", @@ -240,6 +245,16 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Function.createLakebasePool", label: "createLakebasePool" }, + { + type: "doc", + id: "api/appkit/Function.extractServingEndpoints", + label: "extractServingEndpoints" + }, + { + type: "doc", + id: "api/appkit/Function.findServerFile", + label: "findServerFile" + }, { type: "doc", id: "api/appkit/Function.generateDatabaseCredential", diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts new file mode 100644 index 00000000..6d5f159f --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts @@ -0,0 +1,117 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { useServingInvoke } from "../use-serving-invoke"; + +describe("useServingInvoke", () => { + beforeEach(() => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ choices: [] }), { status: 200 }), + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("initial state is idle", () => { + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + expect(result.current.data).toBeNull(); + expect(result.current.loading).toBe(false); + expect(result.current.error).toBeNull(); + expect(typeof result.current.invoke).toBe("function"); + }); + + test("calls fetch to correct URL on invoke", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + const { result } = renderHook(() => + useServingInvoke({ messages: [{ role: "user", content: "Hello" }] }), + ); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalledWith( + "/api/serving/invoke", + expect.objectContaining({ + method: "POST", + body: JSON.stringify({ + messages: [{ role: "user", content: "Hello" }], + }), + }), + ); + }); + }); + + test("uses alias in URL when provided", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + const { result } = renderHook(() => + useServingInvoke({ messages: [] }, { alias: "llm" }), + ); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalledWith( + "/api/serving/llm/invoke", + expect.any(Object), + ); + }); + }); + + test("sets data on successful response", async () => { + const responseData = { + choices: [{ message: { content: "Hi" } }], + }; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(responseData), { status: 200 }), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(result.current.data).toEqual(responseData); + expect(result.current.loading).toBe(false); + }); + }); + + test("sets error on failed response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ error: "Not found" }), { status: 404 }), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + await act(async () => { + result.current.invoke(); + // Wait for the fetch promise chain to resolve + await new Promise((r) => setTimeout(r, 10)); + }); + + await waitFor(() => { + expect(result.current.error).toBe("Not found"); + expect(result.current.loading).toBe(false); + }); + }); + + test("auto starts when autoStart is true", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + renderHook(() => useServingInvoke({ messages: [] }, { autoStart: true })); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts new file mode 100644 index 00000000..0a1a736c --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts @@ -0,0 +1,271 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { afterEach, describe, expect, test, vi } from "vitest"; + +// Mock connectSSE — capture callbacks so we can simulate SSE events +let capturedCallbacks: { + onMessage?: (msg: { data: string }) => void; + onError?: (err: Error) => void; + signal?: AbortSignal; +} = {}; + +let resolveStream: (() => void) | null = null; + +const mockConnectSSE = vi.fn().mockImplementation((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + // Also resolve after a tick as fallback for tests that don't manually resolve + setTimeout(resolve, 0); + }); +}); + +vi.mock("@/js", () => ({ + connectSSE: (...args: unknown[]) => mockConnectSSE(...args), +})); + +import { useServingStream } from "../use-serving-stream"; + +describe("useServingStream", () => { + afterEach(() => { + capturedCallbacks = {}; + resolveStream = null; + vi.clearAllMocks(); + }); + + test("initial state is idle", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + expect(result.current.error).toBeNull(); + expect(typeof result.current.stream).toBe("function"); + expect(typeof result.current.reset).toBe("function"); + }); + + test("calls connectSSE with correct URL on stream", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + url: "/api/serving/stream", + payload: JSON.stringify({ messages: [] }), + }), + ); + }); + + test("uses alias in URL when provided", () => { + const { result } = renderHook(() => + useServingStream({ messages: [] }, { alias: "embedder" }), + ); + + act(() => { + result.current.stream(); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + url: "/api/serving/embedder/stream", + }), + ); + }); + + test("sets streaming to true when stream() is called", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(result.current.streaming).toBe(true); + }); + + test("accumulates chunks from onMessage", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 2 }) }); + }); + + expect(result.current.chunks).toEqual([{ id: 1 }, { id: 2 }]); + }); + + test("accumulates chunks with error field as normal data", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ + data: JSON.stringify({ error: "Model overloaded" }), + }); + }); + + // Chunks with an `error` field are treated as data, not stream errors. + // Transport-level errors are delivered via onError callback instead. + expect(result.current.chunks).toEqual([{ error: "Model overloaded" }]); + expect(result.current.error).toBeNull(); + expect(result.current.streaming).toBe(true); + }); + + test("sets error from onError callback", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onError?.(new Error("Connection lost")); + }); + + expect(result.current.error).toBe("Connection lost"); + expect(result.current.streaming).toBe(false); + }); + + test("silently skips malformed JSON messages", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: "not valid json{" }); + }); + + // No chunks added, no error set + expect(result.current.chunks).toEqual([]); + expect(result.current.error).toBeNull(); + }); + + test("reset() clears state and aborts active stream", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + expect(result.current.chunks).toHaveLength(1); + expect(result.current.streaming).toBe(true); + + act(() => { + result.current.reset(); + }); + + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + expect(result.current.error).toBeNull(); + }); + + test("autoStart triggers stream on mount", async () => { + renderHook(() => useServingStream({ messages: [] }, { autoStart: true })); + + await waitFor(() => { + expect(mockConnectSSE).toHaveBeenCalled(); + }); + }); + + test("passes abort signal to connectSSE", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(capturedCallbacks.signal).toBeDefined(); + expect(capturedCallbacks.signal?.aborted).toBe(false); + }); + + test("aborts stream on unmount", () => { + const { result, unmount } = renderHook(() => + useServingStream({ messages: [] }), + ); + + act(() => { + result.current.stream(); + }); + + const signal = capturedCallbacks.signal; + expect(signal?.aborted).toBe(false); + + unmount(); + + expect(signal?.aborted).toBe(true); + }); + + test("sets streaming to false when connectSSE resolves", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + await waitFor(() => { + expect(result.current.streaming).toBe(false); + }); + }); + + test("calls onComplete with accumulated chunks when stream finishes", async () => { + const onComplete = vi.fn(); + + // Use a controllable mock so stream doesn't auto-resolve + mockConnectSSE.mockImplementationOnce((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + }); + }); + + const { result } = renderHook(() => + useServingStream({ messages: [] }, { onComplete }), + ); + + act(() => { + result.current.stream(); + }); + + // Send two chunks + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 2 }) }); + }); + + expect(onComplete).not.toHaveBeenCalled(); + + // Complete the stream + await act(async () => { + resolveStream?.(); + await new Promise((r) => setTimeout(r, 0)); + }); + + expect(onComplete).toHaveBeenCalledWith([{ id: 1 }, { id: 2 }]); + }); +}); diff --git a/packages/appkit-ui/src/react/hooks/index.ts b/packages/appkit-ui/src/react/hooks/index.ts index 84d51b53..a425b010 100644 --- a/packages/appkit-ui/src/react/hooks/index.ts +++ b/packages/appkit-ui/src/react/hooks/index.ts @@ -2,8 +2,13 @@ export type { AnalyticsFormat, InferResultByFormat, InferRowType, + InferServingChunk, + InferServingRequest, + InferServingResponse, PluginRegistry, QueryRegistry, + ServingAlias, + ServingEndpointRegistry, TypedArrowTable, UseAnalyticsQueryOptions, UseAnalyticsQueryResult, @@ -15,3 +20,13 @@ export { useChartData, } from "./use-chart-data"; export { usePluginClientConfig } from "./use-plugin-config"; +export { + type UseServingInvokeOptions, + type UseServingInvokeResult, + useServingInvoke, +} from "./use-serving-invoke"; +export { + type UseServingStreamOptions, + type UseServingStreamResult, + useServingStream, +} from "./use-serving-stream"; diff --git a/packages/appkit-ui/src/react/hooks/types.ts b/packages/appkit-ui/src/react/hooks/types.ts index 5db725fc..19ce1fac 100644 --- a/packages/appkit-ui/src/react/hooks/types.ts +++ b/packages/appkit-ui/src/react/hooks/types.ts @@ -134,3 +134,54 @@ export type InferParams = K extends AugmentedRegistry export interface PluginRegistry { [key: string]: Record; } + +// ============================================================================ +// Serving Endpoint Registry +// ============================================================================ + +/** + * Serving endpoint registry for type-safe alias names. + * Extend this interface via module augmentation to get alias autocomplete: + * + * @example + * ```typescript + * // Auto-generated by appKitServingTypesPlugin() + * declare module "@databricks/appkit-ui/react" { + * interface ServingEndpointRegistry { + * llm: { request: {...}; response: {...}; chunk: {...} }; + * } + * } + * ``` + */ +// biome-ignore lint/suspicious/noEmptyInterface: intentionally empty — populated via module augmentation +export interface ServingEndpointRegistry {} + +/** Resolves to registry keys if populated, otherwise string */ +export type ServingAlias = + AugmentedRegistry extends never + ? string + : AugmentedRegistry; + +/** Infers chunk type from registry when alias is a known key */ +export type InferServingChunk = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { chunk: infer C } + ? C + : unknown + : unknown; + +/** Infers response type from registry when alias is a known key */ +export type InferServingResponse = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { response: infer R } + ? R + : unknown + : unknown; + +/** Infers request type from registry when alias is a known key */ +export type InferServingRequest = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { request: infer Req } + ? Req + : Record + : Record; diff --git a/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts new file mode 100644 index 00000000..343a5e71 --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts @@ -0,0 +1,103 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import type { + InferServingRequest, + InferServingResponse, + ServingAlias, +} from "./types"; + +export interface UseServingInvokeOptions< + K extends ServingAlias = ServingAlias, +> { + /** Endpoint alias for named mode. Omit for default mode. */ + alias?: K; + /** If false, does not invoke automatically on mount. Default: false */ + autoStart?: boolean; +} + +export interface UseServingInvokeResult { + /** Trigger the invocation. Returns the response data, or null on error/abort. */ + invoke: () => Promise; + /** Response data, null until loaded. */ + data: T | null; + /** Whether a request is in progress. */ + loading: boolean; + /** Error message, if any. */ + error: string | null; +} + +/** + * Hook for non-streaming invocation of a serving endpoint. + * Calls `POST /api/serving/invoke` (default) or `POST /api/serving/{alias}/invoke` (named). + * + * When the type generator has populated `ServingEndpointRegistry`, the response type + * is automatically inferred from the endpoint's OpenAPI schema. + */ +export function useServingInvoke( + body: InferServingRequest, + options: UseServingInvokeOptions = {} as UseServingInvokeOptions, +): UseServingInvokeResult> { + type TResponse = InferServingResponse; + const { alias, autoStart = false } = options; + + const [data, setData] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const abortControllerRef = useRef(null); + + const urlSuffix = alias + ? `/api/serving/${encodeURIComponent(String(alias))}/invoke` + : "/api/serving/invoke"; + + const bodyJson = JSON.stringify(body); + + const invoke = useCallback((): Promise => { + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } + + setLoading(true); + setError(null); + setData(null); + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + return fetch(urlSuffix, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: bodyJson, + signal: abortController.signal, + }) + .then(async (res) => { + if (!res.ok) { + const errorBody = await res.json().catch(() => null); + throw new Error(errorBody?.error || `HTTP ${res.status}`); + } + return res.json(); + }) + .then((result: TResponse) => { + if (abortController.signal.aborted) return null; + setData(result); + setLoading(false); + return result; + }) + .catch((err: Error) => { + if (abortController.signal.aborted) return null; + setError(err.message || "Request failed"); + setLoading(false); + return null; + }); + }, [urlSuffix, bodyJson]); + + useEffect(() => { + if (autoStart) { + invoke(); + } + + return () => { + abortControllerRef.current?.abort(); + }; + }, [invoke, autoStart]); + + return { invoke, data, loading, error }; +} diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts new file mode 100644 index 00000000..4801d94c --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -0,0 +1,123 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import { connectSSE } from "@/js"; +import type { + InferServingChunk, + InferServingRequest, + ServingAlias, +} from "./types"; + +export interface UseServingStreamOptions< + K extends ServingAlias = ServingAlias, + T = InferServingChunk, +> { + /** Endpoint alias for named mode. Omit for default mode. */ + alias?: K; + /** If true, starts streaming automatically on mount. Default: false */ + autoStart?: boolean; + /** Called with accumulated chunks when the stream completes successfully. */ + onComplete?: (chunks: T[]) => void; +} + +export interface UseServingStreamResult { + /** Trigger the streaming invocation. */ + stream: () => void; + /** Accumulated chunks received so far. */ + chunks: T[]; + /** Whether streaming is in progress. */ + streaming: boolean; + /** Error message, if any. */ + error: string | null; + /** Reset chunks and abort any active stream. */ + reset: () => void; +} + +/** + * Hook for streaming invocation of a serving endpoint via SSE. + * Calls `POST /api/serving/stream` (default) or `POST /api/serving/{alias}/stream` (named). + * Accumulates parsed chunks in state. + * + * When the type generator has populated `ServingEndpointRegistry`, the chunk type + * is automatically inferred from the endpoint's OpenAPI schema. + */ +export function useServingStream( + body: InferServingRequest, + options: UseServingStreamOptions = {} as UseServingStreamOptions, +): UseServingStreamResult> { + type TChunk = InferServingChunk; + const { alias, autoStart = false, onComplete } = options; + + const [chunks, setChunks] = useState([]); + const [streaming, setStreaming] = useState(false); + const [error, setError] = useState(null); + const abortControllerRef = useRef(null); + const chunksRef = useRef([]); + const onCompleteRef = useRef(onComplete); + onCompleteRef.current = onComplete; + + const urlSuffix = alias + ? `/api/serving/${encodeURIComponent(String(alias))}/stream` + : "/api/serving/stream"; + + const reset = useCallback(() => { + abortControllerRef.current?.abort(); + abortControllerRef.current = null; + chunksRef.current = []; + setChunks([]); + setStreaming(false); + setError(null); + }, []); + + const bodyJson = JSON.stringify(body); + + const stream = useCallback(() => { + // Abort any existing stream + abortControllerRef.current?.abort(); + + setStreaming(true); + setError(null); + setChunks([]); + chunksRef.current = []; + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + connectSSE({ + url: urlSuffix, + payload: bodyJson, + signal: abortController.signal, + onMessage: async (message) => { + if (abortController.signal.aborted) return; + try { + const parsed = JSON.parse(message.data); + + chunksRef.current = [...chunksRef.current, parsed as TChunk]; + setChunks(chunksRef.current); + } catch { + // Skip malformed messages + } + }, + onError: (err) => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError(err instanceof Error ? err.message : "Streaming failed"); + }, + }).then(() => { + if (abortController.signal.aborted) return; + // Stream completed + setStreaming(false); + onCompleteRef.current?.(chunksRef.current); + }); + }, [urlSuffix, bodyJson]); + + useEffect(() => { + if (autoStart) { + stream(); + } + + return () => { + abortControllerRef.current?.abort(); + }; + }, [stream, autoStart]); + + return { stream, chunks, streaming, error, reset }; +} diff --git a/packages/appkit/package.json b/packages/appkit/package.json index 9e810b97..06da3ee1 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -50,6 +50,7 @@ "typecheck": "tsc --noEmit" }, "dependencies": { + "@ast-grep/napi": "0.37.0", "@databricks/lakebase": "workspace:*", "@databricks/sdk-experimental": "0.16.0", "@opentelemetry/api": "1.9.0", diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index 662a9178..3df5572b 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -81,6 +81,10 @@ export { SpanStatusCode, type TelemetryConfig, } from "./telemetry"; - +export { + extractServingEndpoints, + findServerFile, +} from "./type-generator/serving/server-file-extractor"; +export { appKitServingTypesPlugin } from "./type-generator/serving/vite-plugin"; // Vite plugin and type generation export { appKitTypesPlugin } from "./type-generator/vite-plugin"; diff --git a/packages/appkit/src/plugins/serving/schema-filter.ts b/packages/appkit/src/plugins/serving/schema-filter.ts index 6e52294a..07683ede 100644 --- a/packages/appkit/src/plugins/serving/schema-filter.ts +++ b/packages/appkit/src/plugins/serving/schema-filter.ts @@ -1,19 +1,9 @@ import fs from "node:fs/promises"; import { createLogger } from "../../logging/logger"; - -const CACHE_VERSION = "1"; - -interface ServingCacheEntry { - hash: string; - requestType: string; - responseType: string; - chunkType: string | null; -} - -interface ServingCache { - version: string; - endpoints: Record; -} +import { + CACHE_VERSION, + type ServingCache, +} from "../../type-generator/serving/cache"; const logger = createLogger("serving:schema-filter"); diff --git a/packages/appkit/src/type-generator/serving/cache.ts b/packages/appkit/src/type-generator/serving/cache.ts new file mode 100644 index 00000000..2737f117 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/cache.ts @@ -0,0 +1,55 @@ +import crypto from "node:crypto"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { createLogger } from "../../logging/logger"; + +const logger = createLogger("type-generator:serving:cache"); + +export const CACHE_VERSION = "1"; +const CACHE_FILE = ".appkit-serving-types-cache.json"; +const CACHE_DIR = path.join( + process.cwd(), + "node_modules", + ".databricks", + "appkit", +); + +export interface ServingCacheEntry { + hash: string; + requestType: string; + responseType: string; + chunkType: string | null; +} + +export interface ServingCache { + version: string; + endpoints: Record; +} + +export function hashSchema(schemaJson: string): string { + return crypto.createHash("sha256").update(schemaJson).digest("hex"); +} + +export async function loadServingCache(): Promise { + const cachePath = path.join(CACHE_DIR, CACHE_FILE); + try { + await fs.mkdir(CACHE_DIR, { recursive: true }); + const raw = await fs.readFile(cachePath, "utf8"); + const cache = JSON.parse(raw) as ServingCache; + if (cache.version === CACHE_VERSION) { + return cache; + } + logger.debug("Cache version mismatch, starting fresh"); + } catch (err) { + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + logger.warn("Cache file is corrupted, flushing cache completely."); + } + } + return { version: CACHE_VERSION, endpoints: {} }; +} + +export async function saveServingCache(cache: ServingCache): Promise { + const cachePath = path.join(CACHE_DIR, CACHE_FILE); + await fs.mkdir(CACHE_DIR, { recursive: true }); + await fs.writeFile(cachePath, JSON.stringify(cache, null, 2), "utf8"); +} diff --git a/packages/appkit/src/type-generator/serving/converter.ts b/packages/appkit/src/type-generator/serving/converter.ts new file mode 100644 index 00000000..1849e720 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/converter.ts @@ -0,0 +1,149 @@ +import type { OpenApiOperation, OpenApiSchema } from "./fetcher"; + +/** + * Converts an OpenAPI schema to a TypeScript type string. + */ +function schemaToTypeString(schema: OpenApiSchema, indent = 0): string { + const pad = " ".repeat(indent); + + if (schema.oneOf) { + return schema.oneOf.map((s) => schemaToTypeString(s, indent)).join(" | "); + } + + if (schema.enum) { + return schema.enum.map((v) => JSON.stringify(v)).join(" | "); + } + + switch (schema.type) { + case "string": + return "string"; + case "integer": + case "number": + return "number"; + case "boolean": + return "boolean"; + case "array": { + if (!schema.items) return "unknown[]"; + const itemType = schemaToTypeString(schema.items, indent); + // Wrap union types in parens for array + if (itemType.includes(" | ") && !itemType.startsWith("{")) { + return `(${itemType})[]`; + } + return `${itemType}[]`; + } + case "object": { + if (!schema.properties) return "Record"; + const required = new Set(schema.required ?? []); + const entries = Object.entries(schema.properties).map(([key, prop]) => { + const optional = !required.has(key) ? "?" : ""; + const nullable = prop.nullable ? " | null" : ""; + const typeStr = schemaToTypeString(prop, indent + 1); + const formatComment = + prop.format && (prop.type === "number" || prop.type === "integer") + ? `/** @openapi ${prop.format}${prop.nullable ? ", nullable" : ""} */\n${pad} ` + : prop.nullable && prop.type === "integer" + ? `/** @openapi integer, nullable */\n${pad} ` + : ""; + return `${pad} ${formatComment}${key}${optional}: ${typeStr}${nullable};`; + }); + return `{\n${entries.join("\n")}\n${pad}}`; + } + default: + return "unknown"; + } +} + +/** + * Extracts and converts the request schema from an OpenAPI path operation. + * Strips the `stream` property from the request type. + */ +export function convertRequestSchema(operation: OpenApiOperation): string { + const schema = operation.requestBody?.content?.["application/json"]?.schema; + if (!schema || !schema.properties) return "Record"; + + // Strip `stream` property — the plugin controls this + const { stream: _stream, ...filteredProps } = schema.properties; + const filteredRequired = (schema.required ?? []).filter( + (r) => r !== "stream", + ); + + const filteredSchema: OpenApiSchema = { + ...schema, + properties: filteredProps, + required: filteredRequired.length > 0 ? filteredRequired : undefined, + }; + + return schemaToTypeString(filteredSchema); +} + +/** + * Extracts and converts the response schema from an OpenAPI path operation. + */ +export function convertResponseSchema(operation: OpenApiOperation): string { + const response = operation.responses?.["200"]; + const schema = response?.content?.["application/json"]?.schema; + if (!schema) return "unknown"; + return schemaToTypeString(schema); +} + +/** + * Derives a streaming chunk type from the response schema. + * Returns null if the response doesn't follow OpenAI-compatible format. + * + * OpenAI-compatible heuristic: response has `choices` array where items + * have a `message` object property. + */ +export function deriveChunkType(operation: OpenApiOperation): string | null { + const response = operation.responses?.["200"]; + const schema = response?.content?.["application/json"]?.schema; + if (!schema?.properties) return null; + + const choicesProp = schema.properties.choices; + if (!choicesProp || choicesProp.type !== "array" || !choicesProp.items) + return null; + + const choiceItemProps = choicesProp.items.properties; + if (!choiceItemProps?.message) return null; + + // It's OpenAI-compatible. Build the chunk type by transforming. + const messageSchema = choiceItemProps.message; + + // Build chunk schema: replace message with delta (Partial), make finish_reason nullable, drop usage + const chunkProperties: Record = {}; + + for (const [key, prop] of Object.entries(schema.properties)) { + if (key === "usage") continue; // Drop usage from chunks + if (key === "choices") { + // Transform choices items + const chunkChoiceProps: Record = {}; + for (const [ck, cp] of Object.entries(choiceItemProps)) { + if (ck === "message") { + // Replace message with delta: Partial + chunkChoiceProps.delta = { ...messageSchema }; + } else if (ck === "finish_reason") { + chunkChoiceProps[ck] = { ...cp, nullable: true }; + } else { + chunkChoiceProps[ck] = cp; + } + } + chunkProperties[key] = { + type: "array", + items: { + type: "object", + properties: chunkChoiceProps, + }, + }; + } else { + chunkProperties[key] = prop; + } + } + + const chunkSchema: OpenApiSchema = { + type: "object", + properties: chunkProperties, + }; + + // Delta properties are already optional (no `required` array in the schema), + // so schemaToTypeString renders them with `?:` — no Partial<> wrapper needed. + return schemaToTypeString(chunkSchema); +} diff --git a/packages/appkit/src/type-generator/serving/fetcher.ts b/packages/appkit/src/type-generator/serving/fetcher.ts new file mode 100644 index 00000000..bf733d7b --- /dev/null +++ b/packages/appkit/src/type-generator/serving/fetcher.ts @@ -0,0 +1,158 @@ +import type { WorkspaceClient } from "@databricks/sdk-experimental"; +import { createLogger } from "../../logging/logger"; + +const logger = createLogger("type-generator:serving:fetcher"); + +interface OpenApiSpec { + openapi: string; + info: { title: string; version: string }; + paths: Record>; +} + +export interface OpenApiOperation { + requestBody?: { + content: { + "application/json": { + schema: OpenApiSchema; + }; + }; + }; + responses?: Record< + string, + { + content?: { + "application/json": { + schema: OpenApiSchema; + }; + }; + } + >; +} + +export interface OpenApiSchema { + type?: string; + properties?: Record; + required?: string[]; + items?: OpenApiSchema; + enum?: string[]; + nullable?: boolean; + oneOf?: OpenApiSchema[]; + format?: string; +} + +/** + * Fetches the OpenAPI schema for a serving endpoint. + * Returns null if the endpoint is not found or access is denied. + */ +export async function fetchOpenApiSchema( + client: WorkspaceClient, + endpointName: string, + servedModel?: string, +): Promise<{ spec: OpenApiSpec; pathKey: string } | null> { + const headers = new Headers({ Accept: "application/json" }); + await client.config.authenticate(headers); + + const host = client.config.host; + if (!host) { + logger.warn("Databricks host not configured, skipping schema fetch"); + return null; + } + + const base = host.startsWith("http") ? host : `https://${host}`; + const url = new URL( + `/api/2.0/serving-endpoints/${encodeURIComponent(endpointName)}/openapi`, + base, + ); + + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), 5000); + + try { + const res = await fetch(url.toString(), { + headers, + signal: controller.signal, + }); + + if (!res.ok) { + const body = await res.text().catch(() => ""); + if (res.status === 404) { + logger.warn( + "Endpoint '%s' not found, skipping type generation%s", + endpointName, + body ? `: ${body}` : "", + ); + } else if (res.status === 403) { + logger.warn( + "Access denied to endpoint '%s' schema, skipping type generation%s", + endpointName, + body ? `: ${body}` : "", + ); + } else { + logger.warn( + "Failed to fetch schema for '%s' (HTTP %d), skipping%s", + endpointName, + res.status, + body ? `: ${body}` : "", + ); + } + return null; + } + + const rawSpec: unknown = await res.json(); + if ( + typeof rawSpec !== "object" || + rawSpec === null || + !("paths" in rawSpec) || + typeof (rawSpec as OpenApiSpec).paths !== "object" + ) { + logger.warn( + "Invalid OpenAPI schema structure for '%s', skipping", + endpointName, + ); + return null; + } + const spec = rawSpec as OpenApiSpec; + + // Find the right path key + const pathKeys = Object.keys(spec.paths ?? {}); + if (pathKeys.length === 0) { + logger.warn("No paths in OpenAPI schema for '%s'", endpointName); + return null; + } + + let pathKey: string; + if (servedModel) { + const match = pathKeys.find((k) => k.includes(`/${servedModel}/`)); + if (!match) { + logger.warn( + "Served model '%s' not found in schema for '%s', using first path", + servedModel, + endpointName, + ); + pathKey = pathKeys[0]; + } else { + pathKey = match; + } + } else { + pathKey = pathKeys[0]; + } + + return { spec, pathKey }; + } catch (err) { + if ((err as Error).name === "AbortError") { + logger.warn( + "Timeout fetching schema for '%s', skipping type generation", + endpointName, + ); + } else { + logger.warn( + "Error fetching schema for '%s': %s", + endpointName, + (err as Error).message, + ); + } + return null; + } finally { + clearTimeout(timeout); + } +} diff --git a/packages/appkit/src/type-generator/serving/generator.ts b/packages/appkit/src/type-generator/serving/generator.ts new file mode 100644 index 00000000..44026f89 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/generator.ts @@ -0,0 +1,266 @@ +import fs from "node:fs/promises"; +import { WorkspaceClient } from "@databricks/sdk-experimental"; +import pc from "picocolors"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; +import { + CACHE_VERSION, + hashSchema, + loadServingCache, + type ServingCache, + saveServingCache, +} from "./cache"; +import { + convertRequestSchema, + convertResponseSchema, + deriveChunkType, +} from "./converter"; +import { fetchOpenApiSchema } from "./fetcher"; + +const logger = createLogger("type-generator:serving"); + +const GENERIC_REQUEST = "Record"; +const GENERIC_RESPONSE = "unknown"; +const GENERIC_CHUNK = "unknown"; + +interface GenerateServingTypesOptions { + outFile: string; + endpoints?: Record; + noCache?: boolean; +} + +/** + * Generates TypeScript type declarations for serving endpoints + * by fetching their OpenAPI schemas and converting to TypeScript. + */ +export async function generateServingTypes( + options: GenerateServingTypesOptions, +): Promise { + const { outFile, noCache } = options; + + // Resolve endpoints from config or env + const endpoints = options.endpoints ?? resolveDefaultEndpoints(); + if (Object.keys(endpoints).length === 0) { + logger.debug("No serving endpoints configured, skipping type generation"); + return; + } + + const startTime = performance.now(); + + const cache = noCache + ? { version: CACHE_VERSION, endpoints: {} } + : await loadServingCache(); + + const client = new WorkspaceClient({}); + let updated = false; + + const registryEntries: string[] = []; + const logEntries: Array<{ + alias: string; + status: "HIT" | "MISS"; + error?: string; + }> = []; + + for (const [alias, config] of Object.entries(endpoints)) { + const endpointName = process.env[config.env]; + if (!endpointName) { + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: `env ${config.env} not set`, + }); + continue; + } + + const result = await fetchOpenApiSchema( + client, + endpointName, + config.servedModel, + ); + if (!result) { + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: "schema fetch failed", + }); + continue; + } + + const { spec, pathKey } = result; + const schemaJson = JSON.stringify(spec); + const hash = hashSchema(schemaJson); + + // Check cache + const cached = cache.endpoints[alias]; + if (cached && cached.hash === hash) { + registryEntries.push( + buildRegistryEntry( + alias, + cached.requestType, + cached.responseType, + cached.chunkType, + ), + ); + logEntries.push({ alias, status: "HIT" }); + continue; + } + + // Cache miss — convert + const operation = spec.paths[pathKey]?.post; + if (!operation) { + logEntries.push({ + alias, + status: "MISS", + error: "no POST operation", + }); + continue; + } + + let requestType: string; + let responseType: string; + let chunkType: string | null; + try { + requestType = convertRequestSchema(operation); + responseType = convertResponseSchema(operation); + chunkType = deriveChunkType(operation); + } catch (convErr) { + logger.warn( + "Schema conversion failed for '%s': %s", + alias, + (convErr as Error).message, + ); + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: "schema conversion failed", + }); + continue; + } + + cache.endpoints[alias] = { hash, requestType, responseType, chunkType }; + updated = true; + + registryEntries.push( + buildRegistryEntry(alias, requestType, responseType, chunkType), + ); + logEntries.push({ alias, status: "MISS" }); + } + + // Print formatted table (matching analytics typegen output) + if (logEntries.length > 0) { + const maxNameLen = Math.max(...logEntries.map((e) => e.alias.length)); + const separator = pc.dim("─".repeat(50)); + console.log(""); + console.log( + ` ${pc.bold("Typegen Serving")} ${pc.dim(`(${logEntries.length})`)}`, + ); + console.log(` ${separator}`); + for (const entry of logEntries) { + const tag = + entry.status === "HIT" + ? `cache ${pc.bold(pc.green("HIT "))}` + : `cache ${pc.bold(pc.yellow("MISS "))}`; + const rawName = entry.alias.padEnd(maxNameLen); + const reason = entry.error ? ` ${pc.dim(entry.error)}` : ""; + console.log(` ${tag} ${rawName}${reason}`); + } + const elapsed = ((performance.now() - startTime) / 1000).toFixed(2); + const newCount = logEntries.filter((e) => e.status === "MISS").length; + const cacheCount = logEntries.filter((e) => e.status === "HIT").length; + console.log(` ${separator}`); + console.log( + ` ${newCount} new, ${cacheCount} from cache. ${pc.dim(`${elapsed}s`)}`, + ); + console.log(""); + } + + const output = generateTypeDeclarations(registryEntries); + await fs.writeFile(outFile, output, "utf-8"); + + if (registryEntries.length === 0) { + logger.debug( + "Wrote empty serving types to %s (no endpoints resolved)", + outFile, + ); + } else { + logger.debug("Wrote serving types to %s", outFile); + } + + if (updated) { + await saveServingCache(cache as ServingCache); + } +} + +function resolveDefaultEndpoints(): Record { + if (process.env.DATABRICKS_SERVING_ENDPOINT) { + return { default: { env: "DATABRICKS_SERVING_ENDPOINT" } }; + } + return {}; +} + +function buildRegistryEntry( + alias: string, + requestType: string, + responseType: string, + chunkType: string | null, +): string { + const indent = " "; + const chunkEntry = chunkType ? chunkType : "unknown"; + return ` ${alias}: { +${indent}request: ${indentType(requestType, indent)}; +${indent}response: ${indentType(responseType, indent)}; +${indent}chunk: ${indentType(chunkEntry, indent)}; + };`; +} + +function indentType(typeStr: string, baseIndent: string): string { + if (!typeStr.includes("\n")) return typeStr; + return typeStr + .split("\n") + .map((line, i) => (i === 0 ? line : `${baseIndent}${line}`)) + .join("\n"); +} + +function generateTypeDeclarations(entries: string[]): string { + return `// Auto-generated by AppKit - DO NOT EDIT +// Generated from serving endpoint OpenAPI schemas +import "@databricks/appkit"; +import "@databricks/appkit-ui/react"; + +declare module "@databricks/appkit" { + interface ServingEndpointRegistry { +${entries.join("\n")} + } +} + +declare module "@databricks/appkit-ui/react" { + interface ServingEndpointRegistry { +${entries.join("\n")} + } +} +`; +} diff --git a/packages/appkit/src/type-generator/serving/server-file-extractor.ts b/packages/appkit/src/type-generator/serving/server-file-extractor.ts new file mode 100644 index 00000000..cb1fbe7e --- /dev/null +++ b/packages/appkit/src/type-generator/serving/server-file-extractor.ts @@ -0,0 +1,221 @@ +import fs from "node:fs"; +import path from "node:path"; +import { Lang, parse, type SgNode } from "@ast-grep/napi"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; + +const logger = createLogger("type-generator:serving:extractor"); + +/** + * Candidate paths for the server entry file, relative to the project root. + * Checked in order; the first that exists is used. + * Same convention as plugin sync (sync.ts SERVER_FILE_CANDIDATES). + */ +const SERVER_FILE_CANDIDATES = ["server/index.ts", "server/server.ts"]; + +/** + * Find the server entry file by checking candidate paths in order. + * + * @param basePath - Project root directory to search from + * @returns Absolute path to the server file, or null if none found + */ +export function findServerFile(basePath: string): string | null { + for (const candidate of SERVER_FILE_CANDIDATES) { + const fullPath = path.join(basePath, candidate); + if (fs.existsSync(fullPath)) { + return fullPath; + } + } + return null; +} + +/** + * Extract serving endpoint config from a server file by AST-parsing it. + * Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls + * and extracts the endpoint alias names and their environment variable mappings. + * + * @param serverFilePath - Absolute path to the server entry file + * @returns Extracted endpoint config, or null if not found or not extractable + */ +export function extractServingEndpoints( + serverFilePath: string, +): Record | null { + let content: string; + try { + content = fs.readFileSync(serverFilePath, "utf-8"); + } catch { + logger.debug("Could not read server file: %s", serverFilePath); + return null; + } + + const lang = serverFilePath.endsWith(".tsx") ? Lang.Tsx : Lang.TypeScript; + const ast = parse(lang, content); + const root = ast.root(); + + // Find serving(...) call expressions + const servingCall = findServingCall(root); + if (!servingCall) { + logger.debug("No serving() call found in %s", serverFilePath); + return null; + } + + // Get the first argument (the config object) + const args = servingCall.field("arguments"); + if (!args) { + return null; + } + + const configArg = args.children().find((child) => child.kind() === "object"); + if (!configArg) { + // serving() called with no args or non-object arg + return null; + } + + // Find the "endpoints" property in the config object + const endpointsPair = findProperty(configArg, "endpoints"); + if (!endpointsPair) { + // Config object has no "endpoints" property (e.g. serving({ timeout: 5000 })) + return null; + } + + // Get the value of the endpoints property + const endpointsValue = getPropertyValue(endpointsPair); + if (!endpointsValue || endpointsValue.kind() !== "object") { + // endpoints is a variable reference, not an inline object + logger.debug( + "serving() endpoints is not an inline object literal in %s. " + + "Pass endpoints explicitly via appKitServingTypesPlugin({ endpoints }) in vite.config.ts.", + serverFilePath, + ); + return null; + } + + // Extract each endpoint entry + const endpoints: Record = {}; + const pairs = endpointsValue + .children() + .filter((child) => child.kind() === "pair"); + + for (const pair of pairs) { + const entry = extractEndpointEntry(pair); + if (entry) { + endpoints[entry.alias] = entry.config; + } + } + + if (Object.keys(endpoints).length === 0) { + return null; + } + + logger.debug( + "Extracted %d endpoint(s) from %s: %s", + Object.keys(endpoints).length, + serverFilePath, + Object.keys(endpoints).join(", "), + ); + + return endpoints; +} + +/** + * Find the serving() call expression in the AST. + * Looks for call expressions where the callee identifier is "serving". + */ +function findServingCall(root: SgNode): SgNode | null { + const callExpressions = root.findAll({ + rule: { kind: "call_expression" }, + }); + + for (const call of callExpressions) { + const callee = call.children()[0]; + if (callee?.kind() === "identifier" && callee.text() === "serving") { + return call; + } + } + + return null; +} + +/** + * Find a property (pair node) with the given key name in an object expression. + */ +function findProperty(objectNode: SgNode, propertyName: string): SgNode | null { + const pairs = objectNode + .children() + .filter((child) => child.kind() === "pair"); + + for (const pair of pairs) { + const key = pair.children()[0]; + if (!key) continue; + + const keyText = + key.kind() === "property_identifier" + ? key.text() + : key.kind() === "string" + ? key.text().replace(/^['"]|['"]$/g, "") + : null; + + if (keyText === propertyName) { + return pair; + } + } + + return null; +} + +/** + * Get the value node from a pair (property: value). + * The value is typically the last meaningful child after the colon. + */ +function getPropertyValue(pairNode: SgNode): SgNode | null { + const children = pairNode.children(); + // pair children: [key, ":", value] + return children.length >= 3 ? children[children.length - 1] : null; +} + +/** + * Extract a single endpoint entry from a pair node like: + * `demo: { env: "DATABRICKS_SERVING_ENDPOINT", servedModel: "my-model" }` + */ +function extractEndpointEntry( + pair: SgNode, +): { alias: string; config: EndpointConfig } | null { + const children = pair.children(); + if (children.length < 3) return null; + + // Get alias name (the key) + const keyNode = children[0]; + const alias = + keyNode.kind() === "property_identifier" + ? keyNode.text() + : keyNode.kind() === "string" + ? keyNode.text().replace(/^['"]|['"]$/g, "") + : null; + + if (!alias) return null; + + // Get the value (should be an object like { env: "..." }) + const valueNode = children[children.length - 1]; + if (valueNode.kind() !== "object") return null; + + // Extract env field + const envPair = findProperty(valueNode, "env"); + if (!envPair) return null; + + const envValue = getPropertyValue(envPair); + if (!envValue || envValue.kind() !== "string") return null; + + const env = envValue.text().replace(/^['"]|['"]$/g, ""); + + // Extract optional servedModel field + const config: EndpointConfig = { env }; + const servedModelPair = findProperty(valueNode, "servedModel"); + if (servedModelPair) { + const servedModelValue = getPropertyValue(servedModelPair); + if (servedModelValue?.kind() === "string") { + config.servedModel = servedModelValue.text().replace(/^['"]|['"]$/g, ""); + } + } + + return { alias, config }; +} diff --git a/packages/appkit/src/type-generator/serving/tests/cache.test.ts b/packages/appkit/src/type-generator/serving/tests/cache.test.ts new file mode 100644 index 00000000..1c0ab21c --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/cache.test.ts @@ -0,0 +1,107 @@ +import fs from "node:fs/promises"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { + CACHE_VERSION, + hashSchema, + loadServingCache, + type ServingCache, + saveServingCache, +} from "../cache"; + +vi.mock("node:fs/promises"); + +describe("serving cache", () => { + beforeEach(() => { + vi.mocked(fs.mkdir).mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("hashSchema", () => { + test("returns consistent SHA256 hash", () => { + const hash1 = hashSchema('{"openapi": "3.1.0"}'); + const hash2 = hashSchema('{"openapi": "3.1.0"}'); + expect(hash1).toBe(hash2); + expect(hash1).toHaveLength(64); // SHA256 hex + }); + + test("different inputs produce different hashes", () => { + const hash1 = hashSchema('{"a": 1}'); + const hash2 = hashSchema('{"a": 2}'); + expect(hash1).not.toBe(hash2); + }); + }); + + describe("loadServingCache", () => { + test("returns empty cache when file does not exist", async () => { + vi.mocked(fs.readFile).mockRejectedValue( + Object.assign(new Error("ENOENT"), { code: "ENOENT" }), + ); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + + test("returns parsed cache when file exists with correct version", async () => { + const cached: ServingCache = { + version: CACHE_VERSION, + endpoints: { + llm: { + hash: "abc", + requestType: "{ messages: string[] }", + responseType: "{ model: string }", + chunkType: null, + }, + }, + }; + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(cached)); + + const cache = await loadServingCache(); + expect(cache).toEqual(cached); + }); + + test("flushes cache when version mismatches", async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ version: "0", endpoints: { old: {} } }), + ); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + + test("flushes cache when file is corrupted", async () => { + vi.mocked(fs.readFile).mockResolvedValue("not json"); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + }); + + describe("saveServingCache", () => { + test("writes cache to file", async () => { + vi.mocked(fs.writeFile).mockResolvedValue(); + + const cache: ServingCache = { + version: CACHE_VERSION, + endpoints: { + test: { + hash: "xyz", + requestType: "{}", + responseType: "{}", + chunkType: null, + }, + }, + }; + + await saveServingCache(cache); + + expect(fs.writeFile).toHaveBeenCalledWith( + expect.stringContaining(".appkit-serving-types-cache.json"), + JSON.stringify(cache, null, 2), + "utf8", + ); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/converter.test.ts b/packages/appkit/src/type-generator/serving/tests/converter.test.ts new file mode 100644 index 00000000..ca794fb3 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/converter.test.ts @@ -0,0 +1,278 @@ +import { describe, expect, test } from "vitest"; +import { + convertRequestSchema, + convertResponseSchema, + deriveChunkType, +} from "../converter"; +import type { OpenApiOperation, OpenApiSchema } from "../fetcher"; + +function makeOperation( + requestProps: Record, + responseProps?: Record, + required?: string[], +): OpenApiOperation { + return { + requestBody: { + content: { + "application/json": { + schema: { + type: "object", + properties: requestProps, + required, + }, + }, + }, + }, + responses: responseProps + ? { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: responseProps, + }, + }, + }, + }, + } + : undefined, + }; +} + +describe("converter", () => { + describe("convertRequestSchema", () => { + test("converts string type", () => { + const op = makeOperation({ name: { type: "string" } }); + const result = convertRequestSchema(op); + expect(result).toContain("name?: string;"); + }); + + test("converts integer type to number", () => { + const op = makeOperation({ count: { type: "integer" } }); + expect(convertRequestSchema(op)).toContain("count?: number;"); + }); + + test("converts number type", () => { + const op = makeOperation({ + temp: { type: "number", format: "double" }, + }); + expect(convertRequestSchema(op)).toContain("temp?: number;"); + }); + + test("converts boolean type", () => { + const op = makeOperation({ flag: { type: "boolean" } }); + expect(convertRequestSchema(op)).toContain("flag?: boolean;"); + }); + + test("converts enum to string literal union", () => { + const op = makeOperation({ + role: { type: "string", enum: ["user", "assistant"] }, + }); + const result = convertRequestSchema(op); + expect(result).toContain('"user" | "assistant"'); + }); + + test("converts array type", () => { + const op = makeOperation({ + items: { type: "array", items: { type: "string" } }, + }); + expect(convertRequestSchema(op)).toContain("items?: string[];"); + }); + + test("converts nested object", () => { + const op = makeOperation({ + messages: { + type: "array", + items: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + }, + }); + const result = convertRequestSchema(op); + expect(result).toContain("role?: string;"); + expect(result).toContain("content?: string;"); + }); + + test("handles nullable properties", () => { + const op = makeOperation({ + temp: { type: "number", nullable: true }, + }); + expect(convertRequestSchema(op)).toContain("temp?: number | null;"); + }); + + test("handles oneOf union types", () => { + const op = makeOperation({ + stop: { + oneOf: [ + { type: "string" }, + { type: "array", items: { type: "string" } }, + ], + }, + }); + const result = convertRequestSchema(op); + expect(result).toContain("string | string[]"); + }); + + test("strips stream property from request", () => { + const op = makeOperation({ + messages: { type: "array", items: { type: "string" } }, + stream: { type: "boolean", nullable: true }, + temperature: { type: "number" }, + }); + const result = convertRequestSchema(op); + expect(result).not.toContain("stream"); + expect(result).toContain("messages"); + expect(result).toContain("temperature"); + }); + + test("marks required properties without ?", () => { + const op = makeOperation( + { + messages: { type: "array", items: { type: "string" } }, + temperature: { type: "number" }, + }, + undefined, + ["messages"], + ); + const result = convertRequestSchema(op); + expect(result).toContain("messages: string[];"); + expect(result).toContain("temperature?: number;"); + }); + + test("returns Record for missing schema", () => { + const op: OpenApiOperation = {}; + expect(convertRequestSchema(op)).toBe("Record"); + }); + }); + + describe("convertResponseSchema", () => { + test("converts response schema", () => { + const op = makeOperation( + {}, + { + model: { type: "string" }, + id: { type: "string" }, + }, + ); + const result = convertResponseSchema(op); + expect(result).toContain("model?: string;"); + expect(result).toContain("id?: string;"); + }); + + test("returns unknown for missing response", () => { + const op: OpenApiOperation = {}; + expect(convertResponseSchema(op)).toBe("unknown"); + }); + }); + + describe("deriveChunkType", () => { + test("derives chunk type from OpenAI-compatible response", () => { + const op: OpenApiOperation = { + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + model: { type: "string" }, + choices: { + type: "array", + items: { + type: "object", + properties: { + index: { type: "integer" }, + message: { + type: "object", + properties: { + role: { + type: "string", + enum: ["user", "assistant"], + }, + content: { type: "string" }, + }, + }, + finish_reason: { type: "string" }, + }, + }, + }, + usage: { + type: "object", + properties: { + prompt_tokens: { type: "integer" }, + }, + nullable: true, + }, + id: { type: "string" }, + }, + }, + }, + }, + }, + }, + }; + + const result = deriveChunkType(op); + expect(result).not.toBeNull(); + // Should have delta instead of message + expect(result).toContain("delta"); + expect(result).not.toContain("message"); + // Should make finish_reason nullable + expect(result).toContain("finish_reason"); + expect(result).toContain("| null"); + // Should drop usage + expect(result).not.toContain("usage"); + // Should keep model and id + expect(result).toContain("model"); + expect(result).toContain("id"); + }); + + test("returns null for non-OpenAI response (no choices)", () => { + const op = makeOperation( + {}, + { + predictions: { type: "array", items: { type: "number" } }, + }, + ); + expect(deriveChunkType(op)).toBeNull(); + }); + + test("returns null for choices without message", () => { + const op: OpenApiOperation = { + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + choices: { + type: "array", + items: { + type: "object", + properties: { + score: { type: "number" }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }; + expect(deriveChunkType(op)).toBeNull(); + }); + + test("returns null for missing response", () => { + const op: OpenApiOperation = {}; + expect(deriveChunkType(op)).toBeNull(); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts b/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts new file mode 100644 index 00000000..802540b0 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts @@ -0,0 +1,209 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { fetchOpenApiSchema } from "../fetcher"; + +const mockAuthenticate = vi.fn(async () => {}); + +function createMockClient(host?: string) { + return { + config: { + host, + authenticate: mockAuthenticate, + }, + } as any; +} + +function makeValidSpec( + paths: Record = { "/invocations": { post: {} } }, +) { + return { + openapi: "3.0.0", + info: { title: "test", version: "1" }, + paths, + }; +} + +describe("fetchOpenApiSchema", () => { + beforeEach(() => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(makeValidSpec()), { status: 200 }), + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("returns null when host is not configured", async () => { + const result = await fetchOpenApiSchema(createMockClient(undefined), "ep"); + expect(result).toBeNull(); + }); + + test("returns null on HTTP 404", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Not found", { status: 404 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on HTTP 403", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Forbidden", { status: 403 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on generic error status", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Server error", { status: 500 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on timeout (AbortError)", async () => { + vi.spyOn(globalThis, "fetch").mockRejectedValue( + Object.assign(new Error("The operation was aborted"), { + name: "AbortError", + }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on network error", async () => { + vi.spyOn(globalThis, "fetch").mockRejectedValue(new Error("fetch failed")); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns spec and pathKey for valid response", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/invocations": { post: { requestBody: {} } }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).not.toBeNull(); + expect(result?.pathKey).toBe("/serving-endpoints/ep/invocations"); + expect(result?.spec.openapi).toBe("3.0.0"); + }); + + test("matches servedModel path when provided", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/served-models/gpt4/invocations": { post: {} }, + "/serving-endpoints/ep/invocations": { post: {} }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + "gpt4", + ); + expect(result?.pathKey).toBe( + "/serving-endpoints/ep/served-models/gpt4/invocations", + ); + }); + + test("falls back to first path when servedModel not found", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/invocations": { post: {} }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + "nonexistent-model", + ); + expect(result?.pathKey).toBe("/serving-endpoints/ep/invocations"); + }); + + test("returns null for invalid spec structure (missing paths)", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ openapi: "3.0.0", info: {} }), { + status: 200, + }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).toBeNull(); + }); + + test("returns null when paths object is empty", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(makeValidSpec({})), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).toBeNull(); + }); + + test("authenticates request headers", async () => { + await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(mockAuthenticate).toHaveBeenCalledWith(expect.any(Headers)); + }); + + test("constructs correct URL with encoded endpoint name", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my endpoint", + ); + + expect(fetchSpy).toHaveBeenCalledWith( + expect.stringContaining("/serving-endpoints/my%20endpoint/openapi"), + expect.any(Object), + ); + }); + + test("prepends https when host lacks protocol", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + await fetchOpenApiSchema(createMockClient("host.databricks.com"), "ep"); + + const url = fetchSpy.mock.calls[0][0] as string; + expect(url.startsWith("https://")).toBe(true); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/generator.test.ts b/packages/appkit/src/type-generator/serving/tests/generator.test.ts new file mode 100644 index 00000000..f9d1b378 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/generator.test.ts @@ -0,0 +1,215 @@ +import fs from "node:fs/promises"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { generateServingTypes } from "../generator"; + +vi.mock("node:fs/promises"); + +// Mock cache module +vi.mock("../cache", () => ({ + CACHE_VERSION: "1", + hashSchema: vi.fn(() => "mock-hash"), + loadServingCache: vi.fn(async () => ({ version: "1", endpoints: {} })), + saveServingCache: vi.fn(async () => {}), +})); + +// Mock fetcher +const mockFetchOpenApiSchema = vi.fn(); +vi.mock("../fetcher", () => ({ + fetchOpenApiSchema: (...args: any[]) => mockFetchOpenApiSchema(...args), +})); + +// Mock WorkspaceClient +vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: vi.fn(() => ({ config: {} })), +})); + +const CHAT_OPENAPI_SPEC = { + openapi: "3.1.0", + info: { title: "test", version: "1" }, + paths: { + "/served-models/llm/invocations": { + post: { + requestBody: { + content: { + "application/json": { + schema: { + type: "object", + properties: { + messages: { + type: "array", + items: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + }, + temperature: { type: "number", nullable: true }, + stream: { type: "boolean", nullable: true }, + }, + }, + }, + }, + }, + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + model: { type: "string" }, + choices: { + type: "array", + items: { + type: "object", + properties: { + message: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + finish_reason: { type: "string" }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, +}; + +describe("generateServingTypes", () => { + const outFile = "/tmp/test-serving-types.d.ts"; + + beforeEach(() => { + vi.mocked(fs.writeFile).mockResolvedValue(); + process.env.TEST_SERVING_ENDPOINT = "my-endpoint"; + }); + + afterEach(() => { + delete process.env.TEST_SERVING_ENDPOINT; + delete process.env.DATABRICKS_SERVING_ENDPOINT; + vi.restoreAllMocks(); + }); + + test("generates .d.ts with module augmentation for a chat endpoint", async () => { + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + expect(fs.writeFile).toHaveBeenCalledWith( + outFile, + expect.any(String), + "utf-8", + ); + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + + // Verify module augmentation structure + expect(output).toContain("// Auto-generated by AppKit - DO NOT EDIT"); + expect(output).toContain('import "@databricks/appkit"'); + expect(output).toContain('import "@databricks/appkit-ui/react"'); + expect(output).toContain('declare module "@databricks/appkit"'); + expect(output).toContain('declare module "@databricks/appkit-ui/react"'); + expect(output).toContain("interface ServingEndpointRegistry"); + expect(output).toContain("llm:"); + expect(output).toContain("request:"); + expect(output).toContain("response:"); + expect(output).toContain("chunk:"); + }); + + test("strips stream property from generated request type", async () => { + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + // `stream` should be stripped from request type + expect(output).toContain("messages"); + expect(output).toContain("temperature"); + expect(output).not.toMatch(/\bstream\??\s*:/); + }); + + test("emits generic types when env var is not set", async () => { + delete process.env.TEST_SERVING_ENDPOINT; + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).not.toHaveBeenCalled(); + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("llm:"); + expect(output).toContain("Record"); + }); + + test("skips generation when no endpoints configured and no env var", async () => { + await generateServingTypes({ + outFile, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).not.toHaveBeenCalled(); + expect(fs.writeFile).not.toHaveBeenCalled(); + }); + + test("emits generic types when schema fetch returns null", async () => { + mockFetchOpenApiSchema.mockResolvedValue(null); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("llm:"); + expect(output).toContain("Record"); + }); + + test("resolves default endpoint from DATABRICKS_SERVING_ENDPOINT", async () => { + process.env.DATABRICKS_SERVING_ENDPOINT = "my-default-endpoint"; + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).toHaveBeenCalledWith( + expect.anything(), + "my-default-endpoint", + undefined, + ); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("default:"); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts b/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts new file mode 100644 index 00000000..f0a94709 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts @@ -0,0 +1,213 @@ +import fs from "node:fs"; +import path from "node:path"; +import { afterEach, describe, expect, test, vi } from "vitest"; +import { + extractServingEndpoints, + findServerFile, +} from "../server-file-extractor"; + +describe("findServerFile", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("returns server/index.ts when it exists", () => { + vi.spyOn(fs, "existsSync").mockImplementation((p) => + String(p).endsWith(path.join("server", "index.ts")), + ); + expect(findServerFile("/app")).toBe( + path.join("/app", "server", "index.ts"), + ); + }); + + test("returns server/server.ts when index.ts does not exist", () => { + vi.spyOn(fs, "existsSync").mockImplementation((p) => + String(p).endsWith(path.join("server", "server.ts")), + ); + expect(findServerFile("/app")).toBe( + path.join("/app", "server", "server.ts"), + ); + }); + + test("returns null when no server file exists", () => { + vi.spyOn(fs, "existsSync").mockReturnValue(false); + expect(findServerFile("/app")).toBeNull(); + }); +}); + +describe("extractServingEndpoints", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + function mockServerFile(content: string) { + vi.spyOn(fs, "readFileSync").mockReturnValue(content); + } + + test("extracts inline endpoints from serving() call", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }); + }); + + test("extracts servedModel when present", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT", servedModel: "my-model" }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT", servedModel: "my-model" }, + }); + }); + + test("returns null when serving() has no arguments", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [serving()], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when serving() has config but no endpoints", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ timeout: 5000 }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when serving() has empty config object", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [serving({})], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when endpoints is a variable reference", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +const myEndpoints = { demo: { env: "DATABRICKS_SERVING_ENDPOINT" } }; +createApp({ + plugins: [ + serving({ endpoints: myEndpoints }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when no serving() call exists", () => { + mockServerFile(` +import { createApp, analytics } from '@databricks/appkit'; + +createApp({ + plugins: [analytics({})], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when server file cannot be read", () => { + vi.spyOn(fs, "readFileSync").mockImplementation(() => { + throw new Error("ENOENT"); + }); + + const result = extractServingEndpoints("/app/server/nonexistent.ts"); + expect(result).toBeNull(); + }); + + test("handles single-quoted env values", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: 'DATABRICKS_SERVING_ENDPOINT' }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }); + }); + + test("handles endpoints with trailing commas", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }, + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts b/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts new file mode 100644 index 00000000..bcd10915 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts @@ -0,0 +1,186 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; + +const mockGenerateServingTypes = vi.fn(async () => {}); +const mockFindServerFile = vi.fn((): string | null => null); +const mockExtractServingEndpoints = vi.fn( + (): Record | null => null, +); + +vi.mock("../generator", () => ({ + generateServingTypes: (...args: any[]) => mockGenerateServingTypes(...args), +})); + +vi.mock("../server-file-extractor", () => ({ + findServerFile: (...args: any[]) => mockFindServerFile(...args), + extractServingEndpoints: (...args: any[]) => + mockExtractServingEndpoints(...args), +})); + +import { appKitServingTypesPlugin } from "../vite-plugin"; + +describe("appKitServingTypesPlugin", () => { + const originalEnv = { ...process.env }; + + beforeEach(() => { + mockGenerateServingTypes.mockReset(); + mockFindServerFile.mockReset(); + mockExtractServingEndpoints.mockReset(); + }); + + afterEach(() => { + process.env = { ...originalEnv }; + vi.restoreAllMocks(); + }); + + describe("apply()", () => { + test("returns true when explicit endpoints provided", () => { + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM_ENDPOINT" } }, + }); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when DATABRICKS_SERVING_ENDPOINT is set", () => { + process.env.DATABRICKS_SERVING_ENDPOINT = "my-endpoint"; + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when server file found in cwd", () => { + mockFindServerFile.mockReturnValueOnce("/app/server/index.ts"); + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when server file found in parent dir", () => { + mockFindServerFile + .mockReturnValueOnce(null) // cwd check + .mockReturnValueOnce("/app/server/index.ts"); // parent check + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns false when nothing configured", () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT; + mockFindServerFile.mockReturnValue(null); + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(false); + }); + }); + + describe("configResolved()", () => { + test("resolves outFile relative to config.root", async () => { + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + outFile: expect.stringContaining( + "/app/client/src/appKitServingTypes.d.ts", + ), + }), + ); + }); + + test("uses custom outFile when provided", async () => { + const plugin = appKitServingTypesPlugin({ + outFile: "types/serving.d.ts", + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + outFile: expect.stringContaining("types/serving.d.ts"), + }), + ); + }); + }); + + describe("buildStart()", () => { + test("calls generateServingTypes with explicit endpoints", async () => { + const endpoints = { llm: { env: "LLM_ENDPOINT" } }; + const plugin = appKitServingTypesPlugin({ endpoints }); + (plugin as any).configResolved({ root: "/app/client" }); + + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + endpoints, + noCache: false, + }), + ); + }); + + test("extracts endpoints from server file when not explicit", async () => { + const extracted = { llm: { env: "LLM_EP" } }; + mockFindServerFile.mockReturnValue("/app/server/index.ts"); + mockExtractServingEndpoints.mockReturnValue(extracted); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: extracted }), + ); + }); + + test("passes undefined endpoints when no server file found", async () => { + mockFindServerFile.mockReturnValue(null); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: undefined }), + ); + }); + + test("passes undefined when AST extraction returns null", async () => { + mockFindServerFile.mockReturnValue("/app/server/index.ts"); + mockExtractServingEndpoints.mockReturnValue(null); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: undefined }), + ); + }); + + test("swallows errors in dev mode", async () => { + process.env.NODE_ENV = "development"; + mockGenerateServingTypes.mockRejectedValue(new Error("fetch failed")); + + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + + // Should not throw + await expect((plugin as any).buildStart()).resolves.toBeUndefined(); + }); + + test("rethrows errors in production mode", async () => { + process.env.NODE_ENV = "production"; + mockGenerateServingTypes.mockRejectedValue(new Error("fetch failed")); + + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + + await expect((plugin as any).buildStart()).rejects.toThrow( + "fetch failed", + ); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/vite-plugin.ts b/packages/appkit/src/type-generator/serving/vite-plugin.ts new file mode 100644 index 00000000..9903a253 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/vite-plugin.ts @@ -0,0 +1,109 @@ +import path from "node:path"; +import type { Plugin } from "vite"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; +import { generateServingTypes } from "./generator"; +import { + extractServingEndpoints, + findServerFile, +} from "./server-file-extractor"; + +const logger = createLogger("type-generator:serving:vite-plugin"); + +interface AppKitServingTypesPluginOptions { + /** Path to the output .d.ts file (relative to client root). Default: "src/appKitServingTypes.d.ts" */ + outFile?: string; + /** Endpoint config override. If omitted, auto-discovers from the server file or falls back to DATABRICKS_SERVING_ENDPOINT env var. */ + endpoints?: Record; +} + +/** + * Vite plugin to generate TypeScript types for AppKit serving endpoints. + * Fetches OpenAPI schemas from Databricks and generates a .d.ts with + * ServingEndpointRegistry module augmentation. + * + * Endpoint discovery order: + * 1. Explicit `endpoints` option (override) + * 2. AST extraction from server file (server/index.ts or server/server.ts) + * 3. DATABRICKS_SERVING_ENDPOINT env var (single default endpoint) + */ +export function appKitServingTypesPlugin( + options?: AppKitServingTypesPluginOptions, +): Plugin { + let outFile: string; + let projectRoot: string; + + async function generate() { + try { + // Resolve endpoints: explicit option > server file AST > env var fallback (handled by generator) + let endpoints = options?.endpoints; + if (!endpoints) { + const serverFile = findServerFile(projectRoot); + if (serverFile) { + endpoints = extractServingEndpoints(serverFile) ?? undefined; + } + } + + await generateServingTypes({ + outFile, + endpoints, + noCache: false, + }); + } catch (error) { + if (process.env.NODE_ENV === "production") { + throw error; + } + logger.error("Error generating serving types: %O", error); + } + } + + return { + name: "appkit-serving-types", + + apply() { + // Fast checks — no AST parsing here + if (options?.endpoints && Object.keys(options.endpoints).length > 0) { + return true; + } + + if (process.env.DATABRICKS_SERVING_ENDPOINT) { + return true; + } + + // Check if a server file exists (may contain serving() config) + // Use process.cwd() for apply() since configResolved hasn't run yet + if (findServerFile(process.cwd())) { + return true; + } + + // Also check parent dir (for when cwd is client/) + const parentDir = path.resolve(process.cwd(), ".."); + if (findServerFile(parentDir)) { + return true; + } + + logger.debug( + "No serving endpoints configured. Skipping type generation.", + ); + return false; + }, + + configResolved(config) { + // Resolve project root: go up one level from Vite root (client dir) + // This handles both: + // - pnpm dev: process.cwd() is app root, config.root is client/ + // - pnpm build: process.cwd() is client/ (cd client && vite build), config.root is client/ + projectRoot = path.resolve(config.root, ".."); + outFile = path.resolve( + config.root, + options?.outFile ?? "src/appKitServingTypes.d.ts", + ); + }, + + async buildStart() { + await generate(); + }, + + // No configureServer / watcher — schemas change on endpoint redeploy, not on file edit + }; +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 199fcfb8..9ca11b81 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -242,6 +242,9 @@ importers: packages/appkit: dependencies: + '@ast-grep/napi': + specifier: 0.37.0 + version: 0.37.0 '@databricks/lakebase': specifier: workspace:* version: link:../lakebase From 72610a7d3de4a289aef4471a061a89c1d78fc9f6 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 12:47:22 +0200 Subject: [PATCH 04/13] fix: use structured requestKeys in cache, add stream body override - Store requestKeys[] in serving cache instead of regex-parsing TypeScript type strings in schema-filter (fragile indentation dependency) - Add overrideBody parameter to useServingStream's stream() to allow callers to pass fresh body without waiting for useMemo recomputation - Lazy-init WorkspaceClient in type generator (skip when no endpoints resolve) Signed-off-by: Pawel Kosiec --- .../__tests__/use-serving-stream.test.ts | 20 +++++ .../src/react/hooks/use-serving-stream.ts | 84 ++++++++++--------- .../src/plugins/serving/schema-filter.ts | 26 +----- .../serving/tests/schema-filter.test.ts | 30 +++++-- .../src/type-generator/serving/cache.ts | 1 + .../src/type-generator/serving/converter.ts | 10 +++ .../src/type-generator/serving/generator.ts | 14 +++- .../serving/tests/cache.test.ts | 2 + .../serving/tests/converter.test.ts | 30 +++++++ 9 files changed, 147 insertions(+), 70 deletions(-) diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts index 0a1a736c..1ab0bf44 100644 --- a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts @@ -61,6 +61,26 @@ describe("useServingStream", () => { ); }); + test("uses override body when passed to stream()", () => { + const { result } = renderHook(() => + useServingStream({ messages: [{ role: "user", content: "old" }] }), + ); + + const overrideBody = { + messages: [{ role: "user" as const, content: "new" }], + }; + + act(() => { + result.current.stream(overrideBody); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + payload: JSON.stringify(overrideBody), + }), + ); + }); + test("uses alias in URL when provided", () => { const { result } = renderHook(() => useServingStream({ messages: [] }, { alias: "embedder" }), diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts index 4801d94c..25cb90a7 100644 --- a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -18,9 +18,12 @@ export interface UseServingStreamOptions< onComplete?: (chunks: T[]) => void; } -export interface UseServingStreamResult { - /** Trigger the streaming invocation. */ - stream: () => void; +export interface UseServingStreamResult< + T = unknown, + TBody = Record, +> { + /** Trigger the streaming invocation. Pass an optional body override for this invocation. */ + stream: (overrideBody?: TBody) => void; /** Accumulated chunks received so far. */ chunks: T[]; /** Whether streaming is in progress. */ @@ -42,7 +45,7 @@ export interface UseServingStreamResult { export function useServingStream( body: InferServingRequest, options: UseServingStreamOptions = {} as UseServingStreamOptions, -): UseServingStreamResult> { +): UseServingStreamResult, InferServingRequest> { type TChunk = InferServingChunk; const { alias, autoStart = false, onComplete } = options; @@ -69,45 +72,50 @@ export function useServingStream( const bodyJson = JSON.stringify(body); - const stream = useCallback(() => { - // Abort any existing stream - abortControllerRef.current?.abort(); + const stream = useCallback( + (overrideBody?: InferServingRequest) => { + // Abort any existing stream + abortControllerRef.current?.abort(); - setStreaming(true); - setError(null); - setChunks([]); - chunksRef.current = []; + setStreaming(true); + setError(null); + setChunks([]); + chunksRef.current = []; - const abortController = new AbortController(); - abortControllerRef.current = abortController; + const abortController = new AbortController(); + abortControllerRef.current = abortController; - connectSSE({ - url: urlSuffix, - payload: bodyJson, - signal: abortController.signal, - onMessage: async (message) => { - if (abortController.signal.aborted) return; - try { - const parsed = JSON.parse(message.data); - - chunksRef.current = [...chunksRef.current, parsed as TChunk]; - setChunks(chunksRef.current); - } catch { - // Skip malformed messages - } - }, - onError: (err) => { + const payload = overrideBody ? JSON.stringify(overrideBody) : bodyJson; + + connectSSE({ + url: urlSuffix, + payload, + signal: abortController.signal, + onMessage: async (message) => { + if (abortController.signal.aborted) return; + try { + const parsed = JSON.parse(message.data); + + chunksRef.current = [...chunksRef.current, parsed as TChunk]; + setChunks(chunksRef.current); + } catch { + // Skip malformed messages + } + }, + onError: (err) => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError(err instanceof Error ? err.message : "Streaming failed"); + }, + }).then(() => { if (abortController.signal.aborted) return; + // Stream completed setStreaming(false); - setError(err instanceof Error ? err.message : "Streaming failed"); - }, - }).then(() => { - if (abortController.signal.aborted) return; - // Stream completed - setStreaming(false); - onCompleteRef.current?.(chunksRef.current); - }); - }, [urlSuffix, bodyJson]); + onCompleteRef.current?.(chunksRef.current); + }); + }, + [urlSuffix, bodyJson], + ); useEffect(() => { if (autoStart) { diff --git a/packages/appkit/src/plugins/serving/schema-filter.ts b/packages/appkit/src/plugins/serving/schema-filter.ts index 07683ede..92a25c69 100644 --- a/packages/appkit/src/plugins/serving/schema-filter.ts +++ b/packages/appkit/src/plugins/serving/schema-filter.ts @@ -37,11 +37,8 @@ export async function loadEndpointSchemas( const cache = parsed; for (const [alias, entry] of Object.entries(cache.endpoints)) { - // Extract property keys from the requestType string - // The requestType is a TypeScript object type like "{ messages: ...; temperature: ...; }" - const keys = extractPropertyKeys(entry.requestType); - if (keys.size > 0) { - allowlists.set(alias, keys); + if (entry.requestKeys && entry.requestKeys.length > 0) { + allowlists.set(alias, new Set(entry.requestKeys)); } } } catch (err) { @@ -57,25 +54,6 @@ export async function loadEndpointSchemas( return allowlists; } -/** - * Extracts top-level property keys from a TypeScript object type string. - * Matches patterns like `key:` or `key?:` at the first nesting level. - */ -function extractPropertyKeys(typeStr: string): Set { - const keys = new Set(); - // Match property names at the top level of the object type - // Looking for patterns: ` propertyName:` or ` propertyName?:` - const propRegex = /^\s{2}(?:\/\*\*[^*]*\*\/\s*)?(\w+)\??:/gm; - for ( - let match = propRegex.exec(typeStr); - match !== null; - match = propRegex.exec(typeStr) - ) { - keys.add(match[1]); - } - return keys; -} - /** * Filters a request body against the allowed keys for an endpoint alias. * Returns the filtered body and logs a warning for stripped params. diff --git a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts index 948b47f9..4fc030d8 100644 --- a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts +++ b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts @@ -109,7 +109,7 @@ describe("schema-filter", () => { expect(result.size).toBe(0); }); - test("extracts property keys from cached types", async () => { + test("reads requestKeys from cache entries", async () => { const fs = (await import("node:fs/promises")).default; vi.mocked(fs.readFile).mockResolvedValue( JSON.stringify({ @@ -117,13 +117,10 @@ describe("schema-filter", () => { endpoints: { default: { hash: "abc", - requestType: `{ - messages: string[]; - temperature?: number | null; - max_tokens: number; -}`, + requestType: "{}", responseType: "{}", chunkType: null, + requestKeys: ["messages", "temperature", "max_tokens"], }, }, }), @@ -137,5 +134,26 @@ describe("schema-filter", () => { expect(keys?.has("temperature")).toBe(true); expect(keys?.has("max_tokens")).toBe(true); }); + + test("skips entries without requestKeys (backwards compat)", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + version: "1", + endpoints: { + default: { + hash: "abc", + requestType: "{ messages: string[] }", + responseType: "{}", + chunkType: null, + }, + }, + }), + ); + + const result = await loadEndpointSchemas("/some/path"); + // No requestKeys → passthrough mode (no allowlist) + expect(result.size).toBe(0); + }); }); }); diff --git a/packages/appkit/src/type-generator/serving/cache.ts b/packages/appkit/src/type-generator/serving/cache.ts index 2737f117..dc9bf7e2 100644 --- a/packages/appkit/src/type-generator/serving/cache.ts +++ b/packages/appkit/src/type-generator/serving/cache.ts @@ -19,6 +19,7 @@ export interface ServingCacheEntry { requestType: string; responseType: string; chunkType: string | null; + requestKeys: string[]; } export interface ServingCache { diff --git a/packages/appkit/src/type-generator/serving/converter.ts b/packages/appkit/src/type-generator/serving/converter.ts index 1849e720..b56b0460 100644 --- a/packages/appkit/src/type-generator/serving/converter.ts +++ b/packages/appkit/src/type-generator/serving/converter.ts @@ -53,6 +53,16 @@ function schemaToTypeString(schema: OpenApiSchema, indent = 0): string { } } +/** + * Extracts the top-level property keys from the request schema. + * Strips the `stream` property (plugin-controlled). + */ +export function extractRequestKeys(operation: OpenApiOperation): string[] { + const schema = operation.requestBody?.content?.["application/json"]?.schema; + if (!schema?.properties) return []; + return Object.keys(schema.properties).filter((k) => k !== "stream"); +} + /** * Extracts and converts the request schema from an OpenAPI path operation. * Strips the `stream` property from the request type. diff --git a/packages/appkit/src/type-generator/serving/generator.ts b/packages/appkit/src/type-generator/serving/generator.ts index 44026f89..2cd88619 100644 --- a/packages/appkit/src/type-generator/serving/generator.ts +++ b/packages/appkit/src/type-generator/serving/generator.ts @@ -14,6 +14,7 @@ import { convertRequestSchema, convertResponseSchema, deriveChunkType, + extractRequestKeys, } from "./converter"; import { fetchOpenApiSchema } from "./fetcher"; @@ -51,7 +52,7 @@ export async function generateServingTypes( ? { version: CACHE_VERSION, endpoints: {} } : await loadServingCache(); - const client = new WorkspaceClient({}); + let client: WorkspaceClient | undefined; let updated = false; const registryEntries: string[] = []; @@ -80,6 +81,7 @@ export async function generateServingTypes( continue; } + client ??= new WorkspaceClient({}); const result = await fetchOpenApiSchema( client, endpointName, @@ -135,10 +137,12 @@ export async function generateServingTypes( let requestType: string; let responseType: string; let chunkType: string | null; + let requestKeys: string[]; try { requestType = convertRequestSchema(operation); responseType = convertResponseSchema(operation); chunkType = deriveChunkType(operation); + requestKeys = extractRequestKeys(operation); } catch (convErr) { logger.warn( "Schema conversion failed for '%s': %s", @@ -161,7 +165,13 @@ export async function generateServingTypes( continue; } - cache.endpoints[alias] = { hash, requestType, responseType, chunkType }; + cache.endpoints[alias] = { + hash, + requestType, + responseType, + chunkType, + requestKeys, + }; updated = true; registryEntries.push( diff --git a/packages/appkit/src/type-generator/serving/tests/cache.test.ts b/packages/appkit/src/type-generator/serving/tests/cache.test.ts index 1c0ab21c..0c99c997 100644 --- a/packages/appkit/src/type-generator/serving/tests/cache.test.ts +++ b/packages/appkit/src/type-generator/serving/tests/cache.test.ts @@ -53,6 +53,7 @@ describe("serving cache", () => { requestType: "{ messages: string[] }", responseType: "{ model: string }", chunkType: null, + requestKeys: ["messages"], }, }, }; @@ -91,6 +92,7 @@ describe("serving cache", () => { requestType: "{}", responseType: "{}", chunkType: null, + requestKeys: [], }, }, }; diff --git a/packages/appkit/src/type-generator/serving/tests/converter.test.ts b/packages/appkit/src/type-generator/serving/tests/converter.test.ts index ca794fb3..1be30738 100644 --- a/packages/appkit/src/type-generator/serving/tests/converter.test.ts +++ b/packages/appkit/src/type-generator/serving/tests/converter.test.ts @@ -3,6 +3,7 @@ import { convertRequestSchema, convertResponseSchema, deriveChunkType, + extractRequestKeys, } from "../converter"; import type { OpenApiOperation, OpenApiSchema } from "../fetcher"; @@ -275,4 +276,33 @@ describe("converter", () => { expect(deriveChunkType(op)).toBeNull(); }); }); + + describe("extractRequestKeys", () => { + test("extracts top-level property keys excluding stream", () => { + const op = makeOperation({ + messages: { type: "array", items: { type: "string" } }, + temperature: { type: "number" }, + stream: { type: "boolean", nullable: true }, + }); + expect(extractRequestKeys(op)).toEqual(["messages", "temperature"]); + }); + + test("returns empty array for missing schema", () => { + const op: OpenApiOperation = {}; + expect(extractRequestKeys(op)).toEqual([]); + }); + + test("returns empty array for schema without properties", () => { + const op: OpenApiOperation = { + requestBody: { + content: { + "application/json": { + schema: { type: "object" }, + }, + }, + }, + }; + expect(extractRequestKeys(op)).toEqual([]); + }); + }); }); From 8c4ccfd571c76a830a48944453d7765b05ed837c Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 13:08:10 +0200 Subject: [PATCH 05/13] fix: clear chunks after stream completion in useServingStream Chunks persisted after onComplete, causing the streaming bubble to remain visible alongside the committed message (duplicate response). Now chunks are cleared atomically with setStreaming(false) so React batches all state updates in one render. Signed-off-by: Pawel Kosiec --- .../__tests__/use-serving-stream.test.ts | 37 +++++++++++++++++++ .../src/react/hooks/use-serving-stream.ts | 4 +- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts index 1ab0bf44..ecc00e9f 100644 --- a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts @@ -288,4 +288,41 @@ describe("useServingStream", () => { expect(onComplete).toHaveBeenCalledWith([{ id: 1 }, { id: 2 }]); }); + + test("clears chunks after stream completes", async () => { + // Use a controllable mock so stream doesn't auto-resolve + mockConnectSSE.mockImplementationOnce((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + }); + }); + + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + // Send a chunk + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + expect(result.current.chunks).toEqual([{ id: 1 }]); + + // Complete the stream + await act(async () => { + resolveStream?.(); + await new Promise((r) => setTimeout(r, 0)); + }); + + // Chunks should be cleared after completion + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + }); }); diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts index 25cb90a7..d34b5559 100644 --- a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -109,9 +109,11 @@ export function useServingStream( }, }).then(() => { if (abortController.signal.aborted) return; - // Stream completed + // Stream completed — let onComplete consume chunks, then clear them setStreaming(false); onCompleteRef.current?.(chunksRef.current); + chunksRef.current = []; + setChunks([]); }); }, [urlSuffix, bodyJson], From 3a4c52d4eb5ab4741671f94ef4f6b4d530f5b607 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 13:18:21 +0200 Subject: [PATCH 06/13] fix: revert chunk-clearing in useServingStream completion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Clearing chunks in the hook's .then() handler caused a race with React batching — chunks were empty before the component could commit them. Let consumers decide when to clear via reset() instead. Signed-off-by: Pawel Kosiec --- .../__tests__/use-serving-stream.test.ts | 37 ------------------- .../src/react/hooks/use-serving-stream.ts | 4 +- 2 files changed, 1 insertion(+), 40 deletions(-) diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts index ecc00e9f..1ab0bf44 100644 --- a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts @@ -288,41 +288,4 @@ describe("useServingStream", () => { expect(onComplete).toHaveBeenCalledWith([{ id: 1 }, { id: 2 }]); }); - - test("clears chunks after stream completes", async () => { - // Use a controllable mock so stream doesn't auto-resolve - mockConnectSSE.mockImplementationOnce((opts: any) => { - capturedCallbacks = { - onMessage: opts.onMessage, - onError: opts.onError, - signal: opts.signal, - }; - return new Promise((resolve) => { - resolveStream = resolve; - }); - }); - - const { result } = renderHook(() => useServingStream({ messages: [] })); - - act(() => { - result.current.stream(); - }); - - // Send a chunk - act(() => { - capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); - }); - - expect(result.current.chunks).toEqual([{ id: 1 }]); - - // Complete the stream - await act(async () => { - resolveStream?.(); - await new Promise((r) => setTimeout(r, 0)); - }); - - // Chunks should be cleared after completion - expect(result.current.chunks).toEqual([]); - expect(result.current.streaming).toBe(false); - }); }); diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts index d34b5559..25cb90a7 100644 --- a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -109,11 +109,9 @@ export function useServingStream( }, }).then(() => { if (abortController.signal.aborted) return; - // Stream completed — let onComplete consume chunks, then clear them + // Stream completed setStreaming(false); onCompleteRef.current?.(chunksRef.current); - chunksRef.current = []; - setChunks([]); }); }, [urlSuffix, bodyJson], From 86870670b5c2d03db70dd9efbe917d52eaa70d46 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 14:21:02 +0200 Subject: [PATCH 07/13] fix: add catch handler to connectSSE promise in useServingStream Without a .catch(), if connectSSE rejects the promise is unhandled and setStreaming(false) never fires, leaving the hook in a broken state. This matches the pattern used by the genie chat hook. Signed-off-by: Pawel Kosiec --- .../src/react/hooks/use-serving-stream.ts | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts index 25cb90a7..f0bb7bf2 100644 --- a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -107,12 +107,18 @@ export function useServingStream( setStreaming(false); setError(err instanceof Error ? err.message : "Streaming failed"); }, - }).then(() => { - if (abortController.signal.aborted) return; - // Stream completed - setStreaming(false); - onCompleteRef.current?.(chunksRef.current); - }); + }) + .then(() => { + if (abortController.signal.aborted) return; + // Stream completed + setStreaming(false); + onCompleteRef.current?.(chunksRef.current); + }) + .catch(() => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError("Connection error"); + }); }, [urlSuffix, bodyJson], ); From 1d503def4b6a0235e9d8f50ae64f509bea4b44cc Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 15:29:35 +0200 Subject: [PATCH 08/13] fix: add overrideBody parameter to useServingInvoke The invoke callback is recreated whenever body changes (via useCallback deps), which triggers the useEffect cleanup that aborts in-flight requests. Adding overrideBody allows callers to use a stable body while passing the real payload per-invocation, matching useServingStream. Signed-off-by: Pawel Kosiec --- .../src/react/hooks/use-serving-invoke.ts | 84 ++++++++++--------- 1 file changed, 46 insertions(+), 38 deletions(-) diff --git a/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts index 343a5e71..8e80e82e 100644 --- a/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts +++ b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts @@ -14,9 +14,12 @@ export interface UseServingInvokeOptions< autoStart?: boolean; } -export interface UseServingInvokeResult { - /** Trigger the invocation. Returns the response data, or null on error/abort. */ - invoke: () => Promise; +export interface UseServingInvokeResult< + T = unknown, + TBody = Record, +> { + /** Trigger the invocation. Pass an optional body override for this invocation. */ + invoke: (overrideBody?: TBody) => Promise; /** Response data, null until loaded. */ data: T | null; /** Whether a request is in progress. */ @@ -35,7 +38,7 @@ export interface UseServingInvokeResult { export function useServingInvoke( body: InferServingRequest, options: UseServingInvokeOptions = {} as UseServingInvokeOptions, -): UseServingInvokeResult> { +): UseServingInvokeResult, InferServingRequest> { type TResponse = InferServingResponse; const { alias, autoStart = false } = options; @@ -50,44 +53,49 @@ export function useServingInvoke( const bodyJson = JSON.stringify(body); - const invoke = useCallback((): Promise => { - if (abortControllerRef.current) { - abortControllerRef.current.abort(); - } + const invoke = useCallback( + (overrideBody?: InferServingRequest): Promise => { + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } - setLoading(true); - setError(null); - setData(null); + setLoading(true); + setError(null); + setData(null); - const abortController = new AbortController(); - abortControllerRef.current = abortController; + const abortController = new AbortController(); + abortControllerRef.current = abortController; - return fetch(urlSuffix, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: bodyJson, - signal: abortController.signal, - }) - .then(async (res) => { - if (!res.ok) { - const errorBody = await res.json().catch(() => null); - throw new Error(errorBody?.error || `HTTP ${res.status}`); - } - return res.json(); - }) - .then((result: TResponse) => { - if (abortController.signal.aborted) return null; - setData(result); - setLoading(false); - return result; + const payload = overrideBody ? JSON.stringify(overrideBody) : bodyJson; + + return fetch(urlSuffix, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: payload, + signal: abortController.signal, }) - .catch((err: Error) => { - if (abortController.signal.aborted) return null; - setError(err.message || "Request failed"); - setLoading(false); - return null; - }); - }, [urlSuffix, bodyJson]); + .then(async (res) => { + if (!res.ok) { + const errorBody = await res.json().catch(() => null); + throw new Error(errorBody?.error || `HTTP ${res.status}`); + } + return res.json(); + }) + .then((result: TResponse) => { + if (abortController.signal.aborted) return null; + setData(result); + setLoading(false); + return result; + }) + .catch((err: Error) => { + if (abortController.signal.aborted) return null; + setError(err.message || "Request failed"); + setLoading(false); + return null; + }); + }, + [urlSuffix, bodyJson], + ); useEffect(() => { if (autoStart) { From 380866fe4b029493ed9ce46df8e53b4e20fad63c Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 11:24:30 +0200 Subject: [PATCH 09/13] feat: add serving to dev-playground, template, and docs Integrate the Model Serving plugin into the dev-playground app with a chat-style streaming demo page. Add serving plugin to app templates and appkit init scaffolding. Include plugin documentation and auto-generated API reference docs. Signed-off-by: Pawel Kosiec --- apps/dev-playground/.env.dist | 1 + apps/dev-playground/client/.gitignore | 3 + .../client/src/routeTree.gen.ts | 21 ++ .../client/src/routes/__root.tsx | 8 + .../client/src/routes/index.tsx | 18 ++ .../client/src/routes/serving.route.tsx | 141 ++++++++++++ apps/dev-playground/client/vite.config.ts | 2 + apps/dev-playground/server/index.ts | 10 +- docs/docs/plugins/serving.md | 213 ++++++++++++++++++ template/appkit.plugins.json | 24 ++ template/client/src/App.tsx | 11 + .../client/src/pages/serving/ServingPage.tsx | 126 +++++++++++ template/client/vite.config.ts | 11 +- template/databricks.yml.tmpl | 7 +- tools/generate-app-templates.ts | 15 +- 15 files changed, 605 insertions(+), 6 deletions(-) create mode 100644 apps/dev-playground/client/src/routes/serving.route.tsx create mode 100644 docs/docs/plugins/serving.md create mode 100644 template/client/src/pages/serving/ServingPage.tsx diff --git a/apps/dev-playground/.env.dist b/apps/dev-playground/.env.dist index 23c3265a..80eda94b 100644 --- a/apps/dev-playground/.env.dist +++ b/apps/dev-playground/.env.dist @@ -9,6 +9,7 @@ OTEL_SERVICE_NAME='dev-playground' DATABRICKS_VOLUME_PLAYGROUND= DATABRICKS_VOLUME_OTHER= DATABRICKS_GENIE_SPACE_ID= +DATABRICKS_SERVING_ENDPOINT= LAKEBASE_ENDPOINT='' # Run: databricks postgres list-endpoints projects/{project-id}/branches/{branch-id} — use the `name` field from the output PGHOST= PGUSER= diff --git a/apps/dev-playground/client/.gitignore b/apps/dev-playground/client/.gitignore index a547bf36..267b28f3 100644 --- a/apps/dev-playground/client/.gitignore +++ b/apps/dev-playground/client/.gitignore @@ -12,6 +12,9 @@ dist dist-ssr *.local +# Auto-generated types (endpoint-specific, varies per developer) +src/appKitServingTypes.d.ts + # Editor directories and files .vscode/* !.vscode/extensions.json diff --git a/apps/dev-playground/client/src/routeTree.gen.ts b/apps/dev-playground/client/src/routeTree.gen.ts index c4c38d14..99ac75fc 100644 --- a/apps/dev-playground/client/src/routeTree.gen.ts +++ b/apps/dev-playground/client/src/routeTree.gen.ts @@ -12,6 +12,7 @@ import { Route as rootRouteImport } from './routes/__root' import { Route as TypeSafetyRouteRouteImport } from './routes/type-safety.route' import { Route as TelemetryRouteRouteImport } from './routes/telemetry.route' import { Route as SqlHelpersRouteRouteImport } from './routes/sql-helpers.route' +import { Route as ServingRouteRouteImport } from './routes/serving.route' import { Route as ReconnectRouteRouteImport } from './routes/reconnect.route' import { Route as LakebaseRouteRouteImport } from './routes/lakebase.route' import { Route as GenieRouteRouteImport } from './routes/genie.route' @@ -37,6 +38,11 @@ const SqlHelpersRouteRoute = SqlHelpersRouteRouteImport.update({ path: '/sql-helpers', getParentRoute: () => rootRouteImport, } as any) +const ServingRouteRoute = ServingRouteRouteImport.update({ + id: '/serving', + path: '/serving', + getParentRoute: () => rootRouteImport, +} as any) const ReconnectRouteRoute = ReconnectRouteRouteImport.update({ id: '/reconnect', path: '/reconnect', @@ -93,6 +99,7 @@ export interface FileRoutesByFullPath { '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute '/reconnect': typeof ReconnectRouteRoute + '/serving': typeof ServingRouteRoute '/sql-helpers': typeof SqlHelpersRouteRoute '/telemetry': typeof TelemetryRouteRoute '/type-safety': typeof TypeSafetyRouteRoute @@ -107,6 +114,7 @@ export interface FileRoutesByTo { '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute '/reconnect': typeof ReconnectRouteRoute + '/serving': typeof ServingRouteRoute '/sql-helpers': typeof SqlHelpersRouteRoute '/telemetry': typeof TelemetryRouteRoute '/type-safety': typeof TypeSafetyRouteRoute @@ -122,6 +130,7 @@ export interface FileRoutesById { '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute '/reconnect': typeof ReconnectRouteRoute + '/serving': typeof ServingRouteRoute '/sql-helpers': typeof SqlHelpersRouteRoute '/telemetry': typeof TelemetryRouteRoute '/type-safety': typeof TypeSafetyRouteRoute @@ -138,6 +147,7 @@ export interface FileRouteTypes { | '/genie' | '/lakebase' | '/reconnect' + | '/serving' | '/sql-helpers' | '/telemetry' | '/type-safety' @@ -152,6 +162,7 @@ export interface FileRouteTypes { | '/genie' | '/lakebase' | '/reconnect' + | '/serving' | '/sql-helpers' | '/telemetry' | '/type-safety' @@ -166,6 +177,7 @@ export interface FileRouteTypes { | '/genie' | '/lakebase' | '/reconnect' + | '/serving' | '/sql-helpers' | '/telemetry' | '/type-safety' @@ -181,6 +193,7 @@ export interface RootRouteChildren { GenieRouteRoute: typeof GenieRouteRoute LakebaseRouteRoute: typeof LakebaseRouteRoute ReconnectRouteRoute: typeof ReconnectRouteRoute + ServingRouteRoute: typeof ServingRouteRoute SqlHelpersRouteRoute: typeof SqlHelpersRouteRoute TelemetryRouteRoute: typeof TelemetryRouteRoute TypeSafetyRouteRoute: typeof TypeSafetyRouteRoute @@ -209,6 +222,13 @@ declare module '@tanstack/react-router' { preLoaderRoute: typeof SqlHelpersRouteRouteImport parentRoute: typeof rootRouteImport } + '/serving': { + id: '/serving' + path: '/serving' + fullPath: '/serving' + preLoaderRoute: typeof ServingRouteRouteImport + parentRoute: typeof rootRouteImport + } '/reconnect': { id: '/reconnect' path: '/reconnect' @@ -285,6 +305,7 @@ const rootRouteChildren: RootRouteChildren = { GenieRouteRoute: GenieRouteRoute, LakebaseRouteRoute: LakebaseRouteRoute, ReconnectRouteRoute: ReconnectRouteRoute, + ServingRouteRoute: ServingRouteRoute, SqlHelpersRouteRoute: SqlHelpersRouteRoute, TelemetryRouteRoute: TelemetryRouteRoute, TypeSafetyRouteRoute: TypeSafetyRouteRoute, diff --git a/apps/dev-playground/client/src/routes/__root.tsx b/apps/dev-playground/client/src/routes/__root.tsx index 5cf74ce3..35a2282b 100644 --- a/apps/dev-playground/client/src/routes/__root.tsx +++ b/apps/dev-playground/client/src/routes/__root.tsx @@ -104,6 +104,14 @@ function RootComponent() { Files + + + diff --git a/apps/dev-playground/client/src/routes/index.tsx b/apps/dev-playground/client/src/routes/index.tsx index e331d93c..934b1467 100644 --- a/apps/dev-playground/client/src/routes/index.tsx +++ b/apps/dev-playground/client/src/routes/index.tsx @@ -218,6 +218,24 @@ function IndexRoute() { + + +
+

+ Model Serving +

+

+ Chat with a Databricks Model Serving endpoint using streaming + completions with real-time SSE responses. +

+ +
+
diff --git a/apps/dev-playground/client/src/routes/serving.route.tsx b/apps/dev-playground/client/src/routes/serving.route.tsx new file mode 100644 index 00000000..adbbff53 --- /dev/null +++ b/apps/dev-playground/client/src/routes/serving.route.tsx @@ -0,0 +1,141 @@ +import { useServingStream } from "@databricks/appkit-ui/react"; +import { createFileRoute } from "@tanstack/react-router"; +import { useCallback, useMemo, useState } from "react"; + +export const Route = createFileRoute("/serving")({ + component: ServingRoute, +}); + +interface Message { + id: string; + role: "user" | "assistant"; + content: string; +} + +function extractContent(chunk: unknown): string { + return ( + (chunk as { choices?: { delta?: { content?: string } }[] })?.choices?.[0] + ?.delta?.content ?? "" + ); +} + +function ServingRoute() { + const [input, setInput] = useState(""); + const [messages, setMessages] = useState([]); + + const body = useMemo( + () => ({ + messages: [...messages, { role: "user" as const, content: input }], + }), + [messages, input], + ); + + const onComplete = useCallback((chunks: unknown[]) => { + const content = chunks.map(extractContent).join(""); + if (content) { + setMessages((prev) => [ + ...prev, + { id: crypto.randomUUID(), role: "assistant", content }, + ]); + } + }, []); + + const { stream, chunks, streaming, error, reset } = useServingStream(body, { + onComplete, + }); + + const assistantContent = chunks.map(extractContent).join(""); + + function handleSubmit(e: React.FormEvent) { + e.preventDefault(); + if (!input.trim() || streaming) return; + + setMessages((prev) => [ + ...prev, + { id: crypto.randomUUID(), role: "user", content: input.trim() }, + ]); + setInput(""); + reset(); + // Trigger stream after state update + setTimeout(() => stream(), 0); + } + + return ( +
+
+
+
+

+ Model Serving +

+

+ Chat with a Databricks Model Serving endpoint. Set{" "} + + DATABRICKS_SERVING_ENDPOINT + {" "} + to enable. +

+
+ +
+ {/* Messages area */} +
+ {messages.map((msg) => ( +
+
+

{msg.content}

+
+
+ ))} + + {/* Streaming response */} + {(streaming || assistantContent) && ( +
+
+

+ {assistantContent || "..."} +

+
+
+ )} + + {error && ( +
+ Error: {error} +
+ )} +
+ + {/* Input area */} +
+ setInput(e.target.value)} + placeholder="Send a message..." + className="flex-1 rounded-md border px-3 py-2 text-sm bg-background" + disabled={streaming} + /> + +
+
+
+
+
+ ); +} diff --git a/apps/dev-playground/client/vite.config.ts b/apps/dev-playground/client/vite.config.ts index f892c62f..5f37880b 100644 --- a/apps/dev-playground/client/vite.config.ts +++ b/apps/dev-playground/client/vite.config.ts @@ -1,4 +1,5 @@ import path from "node:path"; +import { appKitServingTypesPlugin } from "@databricks/appkit"; import { tanstackRouter } from "@tanstack/router-plugin/vite"; import react from "@vitejs/plugin-react"; import { defineConfig } from "vite"; @@ -11,6 +12,7 @@ export default defineConfig({ target: "react", autoCodeSplitting: process.env.NODE_ENV !== "development", }), + appKitServingTypesPlugin(), ], server: { hmr: { diff --git a/apps/dev-playground/server/index.ts b/apps/dev-playground/server/index.ts index a4b6a2c6..af05b11f 100644 --- a/apps/dev-playground/server/index.ts +++ b/apps/dev-playground/server/index.ts @@ -1,5 +1,12 @@ import "reflect-metadata"; -import { analytics, createApp, files, genie, server } from "@databricks/appkit"; +import { + analytics, + createApp, + files, + genie, + server, + serving, +} from "@databricks/appkit"; import { WorkspaceClient } from "@databricks/sdk-experimental"; import { lakebaseExamples } from "./lakebase-examples-plugin"; import { reconnect } from "./reconnect-plugin"; @@ -26,6 +33,7 @@ createApp({ }), lakebaseExamples(), files(), + serving(), ], ...(process.env.APPKIT_E2E_TEST && { client: createMockClient() }), }).then((appkit) => { diff --git a/docs/docs/plugins/serving.md b/docs/docs/plugins/serving.md new file mode 100644 index 00000000..4b2d7a54 --- /dev/null +++ b/docs/docs/plugins/serving.md @@ -0,0 +1,213 @@ +--- +sidebar_position: 7 +--- + +# Serving plugin + +Provides an authenticated proxy to [Databricks Model Serving](https://docs.databricks.com/aws/en/machine-learning/model-serving) endpoints, with invoke and streaming support. + +**Key features:** +- Named endpoint aliases for multiple serving endpoints +- Non-streaming (`invoke`) and SSE streaming (`stream`) invocation +- Automatic OpenAPI type generation for request/response schemas +- Request body filtering based on endpoint schema +- On-behalf-of (OBO) user execution + +## Basic usage + +```ts +import { createApp, server, serving } from "@databricks/appkit"; + +await createApp({ + plugins: [ + server(), + serving(), + ], +}); +``` + +With no configuration, the plugin reads `DATABRICKS_SERVING_ENDPOINT` from the environment and registers it under the `default` alias. + +## Configuration options + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `endpoints` | `Record` | `{ default: { env: "DATABRICKS_SERVING_ENDPOINT" } }` | Map of alias names to endpoint configs | +| `timeout` | `number` | `120000` | Request timeout in ms | + +### Endpoint aliases + +Endpoint aliases let you reference multiple serving endpoints by name: + +```ts +serving({ + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT" }, + classifier: { env: "DATABRICKS_SERVING_ENDPOINT_CLASSIFIER" }, + }, +}) +``` + +Each alias maps to an environment variable holding the actual endpoint name. If an endpoint serves multiple models, you can use `servedModel` to bypass traffic routing and target a specific model directly: + +```ts +serving({ + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT", servedModel: "llama-v2" }, + }, +}) +``` + +## Type generation + +The `appKitServingTypesPlugin()` Vite plugin generates TypeScript types from your serving endpoints' OpenAPI schemas. Add it to your `vite.config.ts`: + +```ts +import { appKitServingTypesPlugin } from "@databricks/appkit"; + +export default defineConfig({ + plugins: [ + appKitServingTypesPlugin(), + ], +}); +``` + +The plugin auto-discovers endpoint configuration from your server file (`server/index.ts` or `server/server.ts`) — no manual config passing needed. + +Generated types provide: +- **Alias autocomplete** in both backend (`AppKit.serving("alias")`) and frontend hooks (`useServingStream`, `useServingInvoke`) +- **Typed request/response/chunk** per endpoint based on OpenAPI schemas + +If an endpoint's OpenAPI schema is unavailable (not deployed, env var not set), the plugin generates generic fallback types. The endpoint is still usable — just without typed request/response. + +:::note +Endpoints that don't define a streaming response schema in their OpenAPI spec will have `chunk: unknown`. For these endpoints, use `useServingInvoke` instead of `useServingStream` — the `response` type will still be properly typed. +::: + +## Environment variables + +| Variable | Description | +|----------|-------------| +| `DATABRICKS_SERVING_ENDPOINT` | Default endpoint name (used when `endpoints` config is omitted) | + +When using named endpoints, define a custom environment variable per alias (e.g. `DATABRICKS_SERVING_ENDPOINT_CLASSIFIER`). + +## HTTP endpoints + +### Named mode (with `endpoints` config) + +- `POST /api/serving/:alias/invoke` — Non-streaming invocation +- `POST /api/serving/:alias/stream` — SSE streaming invocation + +### Default mode (no `endpoints` config) + +- `POST /api/serving/invoke` — Non-streaming invocation +- `POST /api/serving/stream` — SSE streaming invocation + +### Request format + +``` +POST /api/serving/:alias/invoke +Content-Type: application/json + +{ + "messages": [ + { "role": "user", "content": "Hello" } + ] +} +``` + +## Programmatic access + +The plugin exports `invoke` and `stream` methods for server-side use: + +```ts +const AppKit = await createApp({ + plugins: [ + server(), + serving({ + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }, + }), + ], +}); + +// Non-streaming +const result = await AppKit.serving("llm").invoke({ + messages: [{ role: "user", content: "Hello" }], +}); + +// Streaming +for await (const chunk of AppKit.serving("llm").stream({ + messages: [{ role: "user", content: "Hello" }], +})) { + console.log(chunk); +} +``` + +## Frontend hooks + +The `@databricks/appkit-ui` package provides React hooks for serving endpoints: + +### useServingStream + +Streaming invocation via SSE: + +```tsx +import { useServingStream } from "@databricks/appkit-ui/react"; + +function ChatStream() { + const { stream, chunks, streaming, error, reset } = useServingStream( + { messages: [{ role: "user", content: "Hello" }] }, + { + alias: "llm", + onComplete: (finalChunks) => { + // Called with all accumulated chunks when the stream finishes + console.log("Stream done, got", finalChunks.length, "chunks"); + }, + }, + ); + + return ( + <> + + + {chunks.map((chunk, i) =>
{JSON.stringify(chunk)}
)} + {error &&

{error}

} + + ); +} +``` + +### useServingInvoke + +Non-streaming invocation. `invoke()` returns a promise with the response data (or `null` on error): + +```tsx +import { useServingInvoke } from "@databricks/appkit-ui/react"; + +function Classify() { + const { invoke, data, loading, error } = useServingInvoke( + { inputs: ["sample text"] }, + { alias: "classifier" }, + ); + + async function handleClick() { + const result = await invoke(); + if (result) { + console.log("Classification result:", result); + } + } + + return ( + <> + + {data &&
{JSON.stringify(data)}
} + {error &&

{error}

} + + ); +} +``` + +Both hooks accept `autoStart: true` to invoke automatically on mount. diff --git a/template/appkit.plugins.json b/template/appkit.plugins.json index cf60a8af..c21d8e80 100644 --- a/template/appkit.plugins.json +++ b/template/appkit.plugins.json @@ -149,6 +149,30 @@ "optional": [] }, "requiredByTemplate": true + }, + "serving": { + "name": "serving", + "displayName": "Model Serving Plugin", + "description": "Authenticated proxy to Databricks Model Serving endpoints", + "package": "@databricks/appkit", + "resources": { + "required": [ + { + "type": "serving_endpoint", + "alias": "Serving Endpoint", + "resourceKey": "serving-endpoint", + "description": "Model Serving endpoint for inference", + "permission": "CAN_QUERY", + "fields": { + "name": { + "env": "DATABRICKS_SERVING_ENDPOINT", + "description": "Serving endpoint name" + } + } + } + ], + "optional": [] + } } } } diff --git a/template/client/src/App.tsx b/template/client/src/App.tsx index fb4c28e6..a94bb5bc 100644 --- a/template/client/src/App.tsx +++ b/template/client/src/App.tsx @@ -17,6 +17,9 @@ import { GeniePage } from './pages/genie/GeniePage'; {{- if .plugins.files}} import { FilesPage } from './pages/files/FilesPage'; {{- end}} +{{- if .plugins.serving}} +import { ServingPage } from './pages/serving/ServingPage'; +{{- end}} const navLinkClass = ({ isActive }: { isActive: boolean }) => `px-3 py-1.5 rounded-md text-sm font-medium transition-colors ${ @@ -53,6 +56,11 @@ function Layout() { Files +{{- end}} +{{- if .plugins.serving}} + + Serving + {{- end}} @@ -80,6 +88,9 @@ const router = createBrowserRouter([ {{- end}} {{- if .plugins.files}} { path: '/files', element: }, +{{- end}} +{{- if .plugins.serving}} + { path: '/serving', element: }, {{- end}} ], }, diff --git a/template/client/src/pages/serving/ServingPage.tsx b/template/client/src/pages/serving/ServingPage.tsx new file mode 100644 index 00000000..d9363986 --- /dev/null +++ b/template/client/src/pages/serving/ServingPage.tsx @@ -0,0 +1,126 @@ +{{if .plugins.serving -}} +import { useServingInvoke } from '@databricks/appkit-ui/react'; +// For streaming endpoints (e.g. chat models), use useServingStream instead: +// import { useServingStream } from '@databricks/appkit-ui/react'; +import { useMemo, useState } from 'react'; + +interface ChatChoice { + message?: { content?: string }; +} + +interface ChatResponse { + choices?: ChatChoice[]; +} + +function extractContent(data: unknown): string { + const resp = data as ChatResponse; + return resp?.choices?.[0]?.message?.content ?? JSON.stringify(data); +} + +interface Message { + id: string; + role: 'user' | 'assistant'; + content: string; +} + +export function ServingPage() { + const [input, setInput] = useState(''); + const [messages, setMessages] = useState([]); + + const body = useMemo( + () => ({ + messages: [...messages, { role: 'user' as const, content: input }], + }), + [messages, input], + ); + + const { invoke, loading, error } = useServingInvoke(body); + // For streaming endpoints (e.g. chat models), use useServingStream: + // const { stream, chunks, streaming, error, reset } = useServingStream(body); + // Then accumulate chunks: chunks.map(c => c?.choices?.[0]?.delta?.content ?? '').join('') + + function handleSubmit(e: React.FormEvent) { + e.preventDefault(); + if (!input.trim() || loading) return; + + setMessages((prev) => [ + ...prev, + { id: crypto.randomUUID(), role: 'user', content: input.trim() }, + ]); + setInput(''); + + void invoke().then((result) => { + if (result) { + setMessages((prev) => [ + ...prev, + { id: crypto.randomUUID(), role: 'assistant', content: extractContent(result) }, + ]); + } + }); + } + + return ( +
+
+

Model Serving

+

+ Chat with a Databricks Model Serving endpoint. +

+
+ +
+
+ {messages.map((msg) => ( +
+
+

{msg.content}

+
+
+ ))} + + {loading && ( +
+
+

...

+
+
+ )} + + {error && ( +
+ Error: {error} +
+ )} +
+ +
+ setInput(e.target.value)} + placeholder="Send a message..." + className="flex-1 rounded-md border px-3 py-2 text-sm bg-background" + disabled={loading} + /> + +
+
+
+ ); +} +{{- end}} diff --git a/template/client/vite.config.ts b/template/client/vite.config.ts index b49d4055..12c1d864 100644 --- a/template/client/vite.config.ts +++ b/template/client/vite.config.ts @@ -2,11 +2,20 @@ import { defineConfig } from 'vite'; import react from '@vitejs/plugin-react'; import tailwindcss from '@tailwindcss/vite'; import path from 'node:path'; +{{- if .plugins.serving}} +import { appKitServingTypesPlugin } from '@databricks/appkit'; +{{- end}} // https://vite.dev/config/ export default defineConfig({ root: __dirname, - plugins: [react(), tailwindcss()], + plugins: [ + react(), + tailwindcss(), +{{- if .plugins.serving}} + appKitServingTypesPlugin(), +{{- end}} + ], server: { middlewareMode: true, }, diff --git a/template/databricks.yml.tmpl b/template/databricks.yml.tmpl index accf7709..77997d31 100644 --- a/template/databricks.yml.tmpl +++ b/template/databricks.yml.tmpl @@ -13,7 +13,7 @@ resources: description: "{{.appDescription}}" source_code_path: ./ -{{- if or .plugins.genie .plugins.files}} +{{- if or .plugins.genie .plugins.files .plugins.serving}} user_api_scopes: {{- if .plugins.genie}} - dashboards.genie @@ -21,8 +21,11 @@ resources: {{- if .plugins.files}} - files.files {{- end}} +{{- if .plugins.serving}} + - serving.serving-endpoints +{{- end}} {{- else}} - # Uncomment to enable on behalf of user API scopes. Available scopes: sql, dashboards.genie, files.files + # Uncomment to enable on behalf of user API scopes. Available scopes: sql, dashboards.genie, files.files, serving.serving-endpoints # user_api_scopes: # - sql {{- end}} diff --git a/tools/generate-app-templates.ts b/tools/generate-app-templates.ts index 4b029121..1eff9357 100644 --- a/tools/generate-app-templates.ts +++ b/tools/generate-app-templates.ts @@ -55,21 +55,23 @@ const FEATURE_DEPENDENCIES: Record = { files: "Volume", genie: "Genie Space", lakebase: "Database", + serving: "Serving Endpoint", }; const APP_TEMPLATES: AppTemplate[] = [ { name: "appkit-all-in-one", - features: ["analytics", "files", "genie", "lakebase"], + features: ["analytics", "files", "genie", "lakebase", "serving"], set: { "analytics.sql-warehouse.id": "placeholder", "files.files.path": "placeholder", "genie.genie-space.id": "placeholder", "lakebase.postgres.branch": "placeholder", "lakebase.postgres.database": "placeholder", + "serving.serving-endpoint.name": "placeholder", }, description: - "Full-stack Node.js app with SQL analytics dashboards, file browser, Genie AI conversations, and Lakebase Autoscaling (Postgres) CRUD", + "Full-stack Node.js app with SQL analytics dashboards, file browser, Genie AI conversations, Lakebase Autoscaling (Postgres) CRUD, and Model Serving", }, { name: "appkit-analytics", @@ -96,6 +98,15 @@ const APP_TEMPLATES: AppTemplate[] = [ }, description: "Node.js app with file browser for Databricks Volumes", }, + { + name: "appkit-serving", + features: ["serving"], + set: { + "serving.serving-endpoint.name": "placeholder", + }, + description: + "Node.js app with Databricks Model Serving endpoint integration", + }, { name: "appkit-lakebase", features: ["lakebase"], From 68f4d67d38735f64513026b114676230df346c94 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 12:49:56 +0200 Subject: [PATCH 10/13] fix: resolve race condition in serving playground chat Build the full message array synchronously in handleSubmit and pass it to stream() via the new overrideBody parameter, instead of relying on useMemo recomputation via setTimeout. Signed-off-by: Pawel Kosiec --- .../client/src/routes/serving.route.tsx | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/apps/dev-playground/client/src/routes/serving.route.tsx b/apps/dev-playground/client/src/routes/serving.route.tsx index adbbff53..7b7e746a 100644 --- a/apps/dev-playground/client/src/routes/serving.route.tsx +++ b/apps/dev-playground/client/src/routes/serving.route.tsx @@ -1,6 +1,6 @@ import { useServingStream } from "@databricks/appkit-ui/react"; import { createFileRoute } from "@tanstack/react-router"; -import { useCallback, useMemo, useState } from "react"; +import { useCallback, useState } from "react"; export const Route = createFileRoute("/serving")({ component: ServingRoute, @@ -23,13 +23,6 @@ function ServingRoute() { const [input, setInput] = useState(""); const [messages, setMessages] = useState([]); - const body = useMemo( - () => ({ - messages: [...messages, { role: "user" as const, content: input }], - }), - [messages, input], - ); - const onComplete = useCallback((chunks: unknown[]) => { const content = chunks.map(extractContent).join(""); if (content) { @@ -40,9 +33,10 @@ function ServingRoute() { } }, []); - const { stream, chunks, streaming, error, reset } = useServingStream(body, { - onComplete, - }); + const { stream, chunks, streaming, error, reset } = useServingStream( + { messages: [] }, + { onComplete }, + ); const assistantContent = chunks.map(extractContent).join(""); @@ -50,14 +44,21 @@ function ServingRoute() { e.preventDefault(); if (!input.trim() || streaming) return; - setMessages((prev) => [ - ...prev, - { id: crypto.randomUUID(), role: "user", content: input.trim() }, - ]); + const userMessage: Message = { + id: crypto.randomUUID(), + role: "user", + content: input.trim(), + }; + + const fullMessages = [ + ...messages, + { role: "user" as const, content: userMessage.content }, + ]; + + setMessages((prev) => [...prev, userMessage]); setInput(""); reset(); - // Trigger stream after state update - setTimeout(() => stream(), 0); + stream({ messages: fullMessages }); } return ( From f46581e00f31263a432b290e5d22cb07e7020291 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 13:22:26 +0200 Subject: [PATCH 11/13] fix: resolve race condition in serving playground chat MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace dual-source pattern (onComplete + streaming bubble) with a useEffect that commits the assistant message on streaming→false transition, then calls reset(). Eliminates both the duplicate response and missing subsequent responses. Signed-off-by: Pawel Kosiec --- .../client/src/routes/serving.route.tsx | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/apps/dev-playground/client/src/routes/serving.route.tsx b/apps/dev-playground/client/src/routes/serving.route.tsx index 7b7e746a..212be7d0 100644 --- a/apps/dev-playground/client/src/routes/serving.route.tsx +++ b/apps/dev-playground/client/src/routes/serving.route.tsx @@ -1,6 +1,6 @@ import { useServingStream } from "@databricks/appkit-ui/react"; import { createFileRoute } from "@tanstack/react-router"; -import { useCallback, useState } from "react"; +import { useEffect, useRef, useState } from "react"; export const Route = createFileRoute("/serving")({ component: ServingRoute, @@ -23,22 +23,28 @@ function ServingRoute() { const [input, setInput] = useState(""); const [messages, setMessages] = useState([]); - const onComplete = useCallback((chunks: unknown[]) => { - const content = chunks.map(extractContent).join(""); - if (content) { + const { stream, chunks, streaming, error, reset } = useServingStream({ + messages: [], + }); + + const streamingContent = chunks.map(extractContent).join(""); + + // Commit assistant message when streaming transitions from true → false + const prevStreamingRef = useRef(false); + useEffect(() => { + if (prevStreamingRef.current && !streaming && streamingContent) { setMessages((prev) => [ ...prev, - { id: crypto.randomUUID(), role: "assistant", content }, + { + id: crypto.randomUUID(), + role: "assistant", + content: streamingContent, + }, ]); + reset(); } - }, []); - - const { stream, chunks, streaming, error, reset } = useServingStream( - { messages: [] }, - { onComplete }, - ); - - const assistantContent = chunks.map(extractContent).join(""); + prevStreamingRef.current = streaming; + }, [streaming, streamingContent, reset]); function handleSubmit(e: React.FormEvent) { e.preventDefault(); @@ -99,11 +105,11 @@ function ServingRoute() { ))} {/* Streaming response */} - {(streaming || assistantContent) && ( + {streaming && (

- {assistantContent || "..."} + {streamingContent || "..."}

From 3d0944ed5e6b8579e11d367849826ef00e4d9711 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 14:21:28 +0200 Subject: [PATCH 12/13] fix: strip id field from messages sent to serving endpoint The Message interface includes an id field used as a React key. When spreading messages into the API payload, this extra field was included, which could cause the Databricks serving endpoint to reject subsequent requests. The first message worked because the messages array was empty. Signed-off-by: Pawel Kosiec --- apps/dev-playground/client/src/routes/serving.route.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/dev-playground/client/src/routes/serving.route.tsx b/apps/dev-playground/client/src/routes/serving.route.tsx index 212be7d0..770d42f4 100644 --- a/apps/dev-playground/client/src/routes/serving.route.tsx +++ b/apps/dev-playground/client/src/routes/serving.route.tsx @@ -57,7 +57,7 @@ function ServingRoute() { }; const fullMessages = [ - ...messages, + ...messages.map(({ role, content }) => ({ role, content })), { role: "user" as const, content: userMessage.content }, ]; From 7e16c81c8efa6941c64eda8fb8951eca776970e0 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 15:30:11 +0200 Subject: [PATCH 13/13] fix: use stable body and overrideBody in template serving page The template computed body via useMemo([messages, input]), causing the invoke callback to be recreated on every state change. This triggered the useEffect cleanup that aborted in-flight requests. Use a stable empty body and pass the real messages via invoke(overrideBody). Signed-off-by: Pawel Kosiec --- .../client/src/pages/serving/ServingPage.tsx | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/template/client/src/pages/serving/ServingPage.tsx b/template/client/src/pages/serving/ServingPage.tsx index d9363986..b80934ba 100644 --- a/template/client/src/pages/serving/ServingPage.tsx +++ b/template/client/src/pages/serving/ServingPage.tsx @@ -2,7 +2,7 @@ import { useServingInvoke } from '@databricks/appkit-ui/react'; // For streaming endpoints (e.g. chat models), use useServingStream instead: // import { useServingStream } from '@databricks/appkit-ui/react'; -import { useMemo, useState } from 'react'; +import { useState } from 'react'; interface ChatChoice { message?: { content?: string }; @@ -27,29 +27,30 @@ export function ServingPage() { const [input, setInput] = useState(''); const [messages, setMessages] = useState([]); - const body = useMemo( - () => ({ - messages: [...messages, { role: 'user' as const, content: input }], - }), - [messages, input], - ); - - const { invoke, loading, error } = useServingInvoke(body); - // For streaming endpoints (e.g. chat models), use useServingStream: - // const { stream, chunks, streaming, error, reset } = useServingStream(body); + const { invoke, loading, error } = useServingInvoke({ messages: [] }); + // For streaming endpoints (e.g. chat models), use useServingStream instead: + // const { stream, chunks, streaming, error, reset } = useServingStream({ messages: [] }); // Then accumulate chunks: chunks.map(c => c?.choices?.[0]?.delta?.content ?? '').join('') function handleSubmit(e: React.FormEvent) { e.preventDefault(); if (!input.trim() || loading) return; - setMessages((prev) => [ - ...prev, - { id: crypto.randomUUID(), role: 'user', content: input.trim() }, - ]); + const userMessage: Message = { + id: crypto.randomUUID(), + role: 'user', + content: input.trim(), + }; + + const fullMessages = [ + ...messages.map(({ role, content }) => ({ role, content })), + { role: 'user' as const, content: userMessage.content }, + ]; + + setMessages((prev) => [...prev, userMessage]); setInput(''); - void invoke().then((result) => { + void invoke({ messages: fullMessages }).then((result) => { if (result) { setMessages((prev) => [ ...prev,