diff --git a/docs/docs/api/appkit/Interface.EndpointConfig.md b/docs/docs/api/appkit/Interface.EndpointConfig.md new file mode 100644 index 00000000..6ee94aa3 --- /dev/null +++ b/docs/docs/api/appkit/Interface.EndpointConfig.md @@ -0,0 +1,21 @@ +# Interface: EndpointConfig + +## Properties + +### env + +```ts +env: string; +``` + +Environment variable holding the endpoint name. + +*** + +### servedModel? + +```ts +optional servedModel: string; +``` + +Target a specific served model (bypasses traffic routing). diff --git a/docs/docs/api/appkit/Interface.ServingEndpointEntry.md b/docs/docs/api/appkit/Interface.ServingEndpointEntry.md new file mode 100644 index 00000000..fa054c3f --- /dev/null +++ b/docs/docs/api/appkit/Interface.ServingEndpointEntry.md @@ -0,0 +1,27 @@ +# Interface: ServingEndpointEntry + +Shape of a single registry entry. + +## Properties + +### chunk + +```ts +chunk: unknown; +``` + +*** + +### request + +```ts +request: Record; +``` + +*** + +### response + +```ts +response: unknown; +``` diff --git a/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md b/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md new file mode 100644 index 00000000..defe5270 --- /dev/null +++ b/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md @@ -0,0 +1,5 @@ +# Interface: ServingEndpointRegistry + +Registry interface for serving endpoint type generation. +Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. +When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. diff --git a/docs/docs/api/appkit/TypeAlias.ServingFactory.md b/docs/docs/api/appkit/TypeAlias.ServingFactory.md new file mode 100644 index 00000000..9ccafef5 --- /dev/null +++ b/docs/docs/api/appkit/TypeAlias.ServingFactory.md @@ -0,0 +1,19 @@ +# Type Alias: ServingFactory + +```ts +type ServingFactory = keyof ServingEndpointRegistry extends never ? (alias?: string) => ServingEndpointMethods : (alias: K) => ServingEndpointMethods; +``` + +Factory function returned by `AppKit.serving`. + +This is a conditional type that adapts based on whether `ServingEndpointRegistry` +has been populated via module augmentation (generated by `appKitServingTypesPlugin()`): + +- **Registry empty (default):** `(alias?: string) => ServingEndpointMethods` — + accepts any alias string with untyped request/response/chunk. +- **Registry populated:** `(alias: K) => ServingEndpointMethods<...>` — + restricts `alias` to known endpoint keys and infers typed request/response/chunk + from the registry entry. + +Run `appKitServingTypesPlugin()` in your Vite config to generate the registry +augmentation and enable full type safety. diff --git a/docs/docs/api/appkit/index.md b/docs/docs/api/appkit/index.md index b5fb7ce0..f4685e04 100644 --- a/docs/docs/api/appkit/index.md +++ b/docs/docs/api/appkit/index.md @@ -33,6 +33,7 @@ plugin architecture, and React integration. | [BasePluginConfig](Interface.BasePluginConfig.md) | Base configuration interface for AppKit plugins | | [CacheConfig](Interface.CacheConfig.md) | Configuration for the CacheInterceptor. Controls TTL, size limits, storage backend, and probabilistic cleanup. | | [DatabaseCredential](Interface.DatabaseCredential.md) | Database credentials with OAuth token for Postgres connection | +| [EndpointConfig](Interface.EndpointConfig.md) | - | | [GenerateDatabaseCredentialRequest](Interface.GenerateDatabaseCredentialRequest.md) | Request parameters for generating database OAuth credentials | | [ITelemetry](Interface.ITelemetry.md) | Plugin-facing interface for OpenTelemetry instrumentation. Provides a thin abstraction over OpenTelemetry APIs for plugins. | | [LakebasePoolConfig](Interface.LakebasePoolConfig.md) | Configuration for creating a Lakebase connection pool | @@ -42,6 +43,8 @@ plugin architecture, and React integration. | [ResourceEntry](Interface.ResourceEntry.md) | Internal representation of a resource in the registry. Extends ResourceRequirement with resolution state and plugin ownership. | | [ResourceFieldEntry](Interface.ResourceFieldEntry.md) | Defines a single field for a resource. Each field has its own environment variable and optional description. Single-value types use one key (e.g. id); multi-value types (database, secret) use multiple (e.g. instance_name, database_name or scope, key). | | [ResourceRequirement](Interface.ResourceRequirement.md) | Declares a resource requirement for a plugin. Can be defined statically in a manifest or dynamically via getResourceRequirements(). Narrows the generated base: type → ResourceType enum, permission → ResourcePermission union. | +| [ServingEndpointEntry](Interface.ServingEndpointEntry.md) | Shape of a single registry entry. | +| [ServingEndpointRegistry](Interface.ServingEndpointRegistry.md) | Registry interface for serving endpoint type generation. Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. | | [StreamExecutionSettings](Interface.StreamExecutionSettings.md) | Execution settings for streaming endpoints. Extends PluginExecutionSettings with SSE stream configuration. | | [TelemetryConfig](Interface.TelemetryConfig.md) | OpenTelemetry configuration for AppKit applications | | [ValidationResult](Interface.ValidationResult.md) | Result of validating all registered resources against the environment. | @@ -54,6 +57,7 @@ plugin architecture, and React integration. | [IAppRouter](TypeAlias.IAppRouter.md) | Express router type for plugin route registration | | [PluginData](TypeAlias.PluginData.md) | Tuple of plugin class, config, and name. Created by `toPlugin()` and passed to `createApp()`. | | [ResourcePermission](TypeAlias.ResourcePermission.md) | Union of all possible permission levels across all resource types. | +| [ServingFactory](TypeAlias.ServingFactory.md) | Factory function returned by `AppKit.serving`. | | [ToPlugin](TypeAlias.ToPlugin.md) | Factory function type returned by `toPlugin()`. Accepts optional config and returns a PluginData tuple. | ## Variables diff --git a/docs/docs/api/appkit/typedoc-sidebar.ts b/docs/docs/api/appkit/typedoc-sidebar.ts index 2f17b1d2..91815e3d 100644 --- a/docs/docs/api/appkit/typedoc-sidebar.ts +++ b/docs/docs/api/appkit/typedoc-sidebar.ts @@ -97,6 +97,11 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Interface.DatabaseCredential", label: "DatabaseCredential" }, + { + type: "doc", + id: "api/appkit/Interface.EndpointConfig", + label: "EndpointConfig" + }, { type: "doc", id: "api/appkit/Interface.GenerateDatabaseCredentialRequest", @@ -142,6 +147,16 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Interface.ResourceRequirement", label: "ResourceRequirement" }, + { + type: "doc", + id: "api/appkit/Interface.ServingEndpointEntry", + label: "ServingEndpointEntry" + }, + { + type: "doc", + id: "api/appkit/Interface.ServingEndpointRegistry", + label: "ServingEndpointRegistry" + }, { type: "doc", id: "api/appkit/Interface.StreamExecutionSettings", @@ -183,6 +198,11 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/TypeAlias.ResourcePermission", label: "ResourcePermission" }, + { + type: "doc", + id: "api/appkit/TypeAlias.ServingFactory", + label: "ServingFactory" + }, { type: "doc", id: "api/appkit/TypeAlias.ToPlugin", diff --git a/docs/static/appkit-ui/styles.gen.css b/docs/static/appkit-ui/styles.gen.css index 9a9a38eb..a2192039 100644 --- a/docs/static/appkit-ui/styles.gen.css +++ b/docs/static/appkit-ui/styles.gen.css @@ -831,9 +831,6 @@ .max-w-\[calc\(100\%-2rem\)\] { max-width: calc(100% - 2rem); } - .max-w-full { - max-width: 100%; - } .max-w-max { max-width: max-content; } @@ -4514,6 +4511,11 @@ width: calc(var(--spacing) * 5); } } + .\[\&_\[data-slot\=scroll-area-viewport\]\>div\]\:\!block { + & [data-slot=scroll-area-viewport]>div { + display: block !important; + } + } .\[\&_a\]\:underline { & a { text-decoration-line: underline; @@ -4637,11 +4639,26 @@ color: var(--muted-foreground); } } + .\[\&_table\]\:block { + & table { + display: block; + } + } + .\[\&_table\]\:max-w-full { + & table { + max-width: 100%; + } + } .\[\&_table\]\:border-collapse { & table { border-collapse: collapse; } } + .\[\&_table\]\:overflow-x-auto { + & table { + overflow-x: auto; + } + } .\[\&_table\]\:text-xs { & table { font-size: var(--text-xs); @@ -4851,6 +4868,11 @@ width: 100%; } } + .\[\&\>\*\]\:min-w-0 { + &>* { + min-width: calc(var(--spacing) * 0); + } + } .\[\&\>\*\]\:focus-visible\:relative { &>* { &:focus-visible { diff --git a/packages/appkit/src/connectors/serving/client.ts b/packages/appkit/src/connectors/serving/client.ts new file mode 100644 index 00000000..6254426d --- /dev/null +++ b/packages/appkit/src/connectors/serving/client.ts @@ -0,0 +1,223 @@ +import { ApiError, type WorkspaceClient } from "@databricks/sdk-experimental"; +import { createLogger } from "../../logging/logger"; +import type { ServingInvokeOptions } from "./types"; + +const logger = createLogger("connectors:serving"); + +/** + * Builds the invocation URL for a serving endpoint. + * Uses `/served-models/{model}/invocations` when servedModel is specified, + * otherwise `/serving-endpoints/{name}/invocations`. + */ +function buildInvocationUrl( + host: string, + endpointName: string, + servedModel?: string, +): string { + const base = host.startsWith("http") ? host : `https://${host}`; + const encodedName = encodeURIComponent(endpointName); + const path = servedModel + ? `/serving-endpoints/${encodedName}/served-models/${encodeURIComponent(servedModel)}/invocations` + : `/serving-endpoints/${encodedName}/invocations`; + return new URL(path, base).toString(); +} + +/** + * Maps upstream Databricks error status codes to appropriate proxy responses. + */ +function mapUpstreamError( + status: number, + body: string, + headers: Headers, +): ApiError { + const safeMessage = body.length > 500 ? `${body.slice(0, 500)}...` : body; + + let parsed: { message?: string; error?: string } = {}; + try { + parsed = JSON.parse(body); + } catch { + // body is not JSON + } + + const message = parsed.message || parsed.error || safeMessage; + + switch (true) { + case status === 400: + return new ApiError(message, "BAD_REQUEST", 400, undefined, []); + case status === 401 || status === 403: + logger.warn("Authentication failure from serving endpoint: %s", message); + return new ApiError(message, "AUTH_FAILURE", status, undefined, []); + case status === 404: + return new ApiError(message, "NOT_FOUND", 404, undefined, []); + case status === 429: { + const retryAfter = headers.get("retry-after"); + const retryMessage = retryAfter + ? `${message} (retry-after: ${retryAfter})` + : message; + return new ApiError(retryMessage, "RATE_LIMITED", 429, undefined, []); + } + case status === 503: + return new ApiError( + "Endpoint loading, retry shortly", + "SERVICE_UNAVAILABLE", + 503, + undefined, + [], + ); + case status >= 500: + return new ApiError(message, "BAD_GATEWAY", 502, undefined, []); + default: + return new ApiError(message, "UNKNOWN", status, undefined, []); + } +} + +/** + * Invokes a serving endpoint and returns the parsed JSON response. + */ +export async function invoke( + client: WorkspaceClient, + endpointName: string, + body: Record, + options?: ServingInvokeOptions, +): Promise { + const host = client.config.host; + if (!host) { + throw new Error( + "Databricks host is not configured. Set DATABRICKS_HOST or configure client.config.host.", + ); + } + + const url = buildInvocationUrl(host, endpointName, options?.servedModel); + + // Always strip `stream` from the body — the connector controls this + const { stream: _stream, ...cleanBody } = body; + + const headers = new Headers({ + "Content-Type": "application/json", + Accept: "application/json", + }); + await client.config.authenticate(headers); + + logger.debug("Invoking endpoint %s at %s", endpointName, url); + + const res = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify(cleanBody), + signal: options?.signal, + }); + + if (!res.ok) { + const text = await res.text(); + throw mapUpstreamError(res.status, text, res.headers); + } + + return res.json(); +} + +/** + * Invokes a serving endpoint with streaming enabled. + * Yields parsed JSON chunks from the NDJSON SSE response. + */ +export async function* stream( + client: WorkspaceClient, + endpointName: string, + body: Record, + options?: ServingInvokeOptions, +): AsyncGenerator { + const host = client.config.host; + if (!host) { + throw new Error( + "Databricks host is not configured. Set DATABRICKS_HOST or configure client.config.host.", + ); + } + + const url = buildInvocationUrl(host, endpointName, options?.servedModel); + + // Strip any user-provided `stream` and inject `stream: true` + const { stream: _stream, ...cleanBody } = body; + const streamBody = { ...cleanBody, stream: true }; + + const headers = new Headers({ + "Content-Type": "application/json", + Accept: "text/event-stream", + }); + await client.config.authenticate(headers); + + logger.debug("Streaming from endpoint %s at %s", endpointName, url); + + const res = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify(streamBody), + signal: options?.signal, + }); + + if (!res.ok) { + const text = await res.text(); + throw mapUpstreamError(res.status, text, res.headers); + } + + if (!res.body) { + throw new Error("Response body is null — streaming not supported"); + } + + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + const MAX_BUFFER_SIZE = 1024 * 1024; // 1 MB + + try { + while (true) { + if (options?.signal?.aborted) break; + + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + + if (buffer.length > MAX_BUFFER_SIZE) { + logger.warn( + "Stream buffer exceeded %d bytes, discarding incomplete data", + MAX_BUFFER_SIZE, + ); + buffer = ""; + } + + // Process complete lines from the buffer + const lines = buffer.split("\n"); + // Keep the last (potentially incomplete) line in the buffer + buffer = lines.pop() ?? ""; + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed || trimmed.startsWith(":")) continue; // skip empty lines and SSE comments + if (trimmed === "data: [DONE]") return; + + if (trimmed.startsWith("data: ")) { + const jsonStr = trimmed.slice(6); + try { + yield JSON.parse(jsonStr); + } catch { + logger.warn("Failed to parse streaming chunk: %s", jsonStr); + } + } + } + } + + // Process any remaining data in the buffer + if (buffer.trim() && !options?.signal?.aborted) { + const trimmed = buffer.trim(); + if (trimmed.startsWith("data: ") && trimmed !== "data: [DONE]") { + try { + yield JSON.parse(trimmed.slice(6)); + } catch { + logger.warn("Failed to parse final streaming chunk: %s", trimmed); + } + } + } + } finally { + reader.cancel().catch(() => {}); + reader.releaseLock(); + } +} diff --git a/packages/appkit/src/connectors/serving/tests/client.test.ts b/packages/appkit/src/connectors/serving/tests/client.test.ts new file mode 100644 index 00000000..6af859ae --- /dev/null +++ b/packages/appkit/src/connectors/serving/tests/client.test.ts @@ -0,0 +1,303 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { invoke, stream } from "../client"; + +const mockAuthenticate = vi.fn(); + +function createMockClient(host = "https://test.databricks.com") { + return { + config: { + host, + authenticate: mockAuthenticate, + }, + } as any; +} + +describe("Serving Connector", () => { + beforeEach(() => { + mockAuthenticate.mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("invoke", () => { + test("constructs correct URL for endpoint invocation", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(fetchSpy).toHaveBeenCalledWith( + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + expect.objectContaining({ method: "POST" }), + ); + }); + + test("constructs correct URL with servedModel override", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke( + client, + "my-endpoint", + { messages: [] }, + { servedModel: "llama-v2" }, + ); + + expect(fetchSpy).toHaveBeenCalledWith( + "https://test.databricks.com/serving-endpoints/my-endpoint/served-models/llama-v2/invocations", + expect.objectContaining({ method: "POST" }), + ); + }); + + test("authenticates request headers", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(mockAuthenticate).toHaveBeenCalledWith(expect.any(Headers)); + }); + + test("strips stream property from body", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { + messages: [], + stream: true, + temperature: 0.7, + }); + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body).toEqual({ messages: [], temperature: 0.7 }); + expect(body.stream).toBeUndefined(); + }); + + test("returns parsed JSON response", async () => { + const responseData = { choices: [{ message: { content: "Hello" } }] }; + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(responseData), { status: 200 }), + ); + + const client = createMockClient(); + const result = await invoke(client, "my-endpoint", { messages: [] }); + + expect(result).toEqual(responseData); + }); + + test("throws ApiError on 400 response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Invalid params" }), { + status: 400, + }), + ); + + const client = createMockClient(); + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Invalid params"); + }); + + test("throws ApiError on 404 response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Endpoint not found" }), { + status: 404, + }), + ); + + const client = createMockClient(); + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Endpoint not found"); + }); + + test("maps 5xx to 502 bad gateway", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Internal error" }), { + status: 500, + }), + ); + + const client = createMockClient(); + try { + await invoke(client, "my-endpoint", { messages: [] }); + expect.unreachable("Should have thrown"); + } catch (err: any) { + expect(err.statusCode).toBe(502); + } + }); + + test("forwards AbortSignal", async () => { + const controller = new AbortController(); + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke( + client, + "my-endpoint", + { messages: [] }, + { signal: controller.signal }, + ); + + expect(fetchSpy.mock.calls[0][1]?.signal).toBe(controller.signal); + }); + + test("throws when host is not configured", async () => { + const client = { + config: { + host: "", + authenticate: mockAuthenticate, + }, + } as any; + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Databricks host is not configured"); + }); + + test("prepends https:// to host without protocol", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient("test.databricks.com"); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(fetchSpy.mock.calls[0][0]).toContain( + "https://test.databricks.com", + ); + }); + }); + + describe("stream", () => { + function createSSEResponse(chunks: string[]) { + const body = `${chunks.join("\n")}\n`; + return new Response(body, { + status: 200, + headers: { "Content-Type": "text/event-stream" }, + }); + } + + test("yields parsed NDJSON chunks", async () => { + const chunks = [ + 'data: {"choices":[{"delta":{"content":"Hello"}}]}', + 'data: {"choices":[{"delta":{"content":" world"}}]}', + "data: [DONE]", + ]; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + createSSEResponse(chunks), + ); + + const client = createMockClient(); + const results: unknown[] = []; + for await (const chunk of stream(client, "my-endpoint", { + messages: [], + })) { + results.push(chunk); + } + + expect(results).toEqual([ + { choices: [{ delta: { content: "Hello" } }] }, + { choices: [{ delta: { content: " world" } }] }, + ]); + }); + + test("injects stream: true into body", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue(createSSEResponse(["data: [DONE]"])); + + const client = createMockClient(); + // Consume the generator + for await (const _ of stream(client, "my-endpoint", { messages: [] })) { + // noop + } + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body.stream).toBe(true); + }); + + test("strips user-provided stream and re-injects", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue(createSSEResponse(["data: [DONE]"])); + + const client = createMockClient(); + for await (const _ of stream(client, "my-endpoint", { + messages: [], + stream: false, + })) { + // noop + } + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body.stream).toBe(true); + }); + + test("skips SSE comments and empty lines", async () => { + const chunks = [ + ": this is a comment", + "", + 'data: {"choices":[{"delta":{"content":"Hi"}}]}', + "", + "data: [DONE]", + ]; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + createSSEResponse(chunks), + ); + + const client = createMockClient(); + const results: unknown[] = []; + for await (const chunk of stream(client, "my-endpoint", { + messages: [], + })) { + results.push(chunk); + } + + expect(results).toHaveLength(1); + expect(results[0]).toEqual({ choices: [{ delta: { content: "Hi" } }] }); + }); + + test("throws on non-OK response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Rate limited" }), { + status: 429, + headers: { "Retry-After": "5" }, + }), + ); + + const client = createMockClient(); + try { + for await (const _ of stream(client, "my-endpoint", { messages: [] })) { + // noop + } + expect.unreachable("Should have thrown"); + } catch (err: any) { + expect(err.statusCode).toBe(429); + } + }); + }); +}); diff --git a/packages/appkit/src/connectors/serving/types.ts b/packages/appkit/src/connectors/serving/types.ts new file mode 100644 index 00000000..6dd1acba --- /dev/null +++ b/packages/appkit/src/connectors/serving/types.ts @@ -0,0 +1,4 @@ +export interface ServingInvokeOptions { + servedModel?: string; + signal?: AbortSignal; +} diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index 8db7f1d7..662a9178 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -48,7 +48,13 @@ export { } from "./errors"; // Plugin authoring export { Plugin, type ToPlugin, toPlugin } from "./plugin"; -export { analytics, files, genie, lakebase, server } from "./plugins"; +export { analytics, files, genie, lakebase, server, serving } from "./plugins"; +export type { + EndpointConfig, + ServingEndpointEntry, + ServingEndpointRegistry, + ServingFactory, +} from "./plugins/serving/types"; // Registry types and utilities for plugin manifests export type { ConfigSchema, diff --git a/packages/appkit/src/plugins/index.ts b/packages/appkit/src/plugins/index.ts index 7caa040f..4d58082f 100644 --- a/packages/appkit/src/plugins/index.ts +++ b/packages/appkit/src/plugins/index.ts @@ -3,3 +3,4 @@ export * from "./files"; export * from "./genie"; export * from "./lakebase"; export * from "./server"; +export * from "./serving"; diff --git a/packages/appkit/src/plugins/serving/defaults.ts b/packages/appkit/src/plugins/serving/defaults.ts new file mode 100644 index 00000000..1fea64c2 --- /dev/null +++ b/packages/appkit/src/plugins/serving/defaults.ts @@ -0,0 +1,26 @@ +import type { StreamExecutionSettings } from "shared"; + +export const servingInvokeDefaults = { + cache: { + enabled: false, + }, + retry: { + enabled: false, + }, + timeout: 120_000, +}; + +export const servingStreamDefaults: StreamExecutionSettings = { + default: { + cache: { + enabled: false, + }, + retry: { + enabled: false, + }, + timeout: 120_000, + }, + stream: { + bufferSize: 200, + }, +}; diff --git a/packages/appkit/src/plugins/serving/index.ts b/packages/appkit/src/plugins/serving/index.ts new file mode 100644 index 00000000..85caf33b --- /dev/null +++ b/packages/appkit/src/plugins/serving/index.ts @@ -0,0 +1,2 @@ +export * from "./serving"; +export * from "./types"; diff --git a/packages/appkit/src/plugins/serving/manifest.json b/packages/appkit/src/plugins/serving/manifest.json new file mode 100644 index 00000000..9ac0845f --- /dev/null +++ b/packages/appkit/src/plugins/serving/manifest.json @@ -0,0 +1,54 @@ +{ + "$schema": "https://databricks.github.io/appkit/schemas/plugin-manifest.schema.json", + "name": "serving", + "displayName": "Model Serving Plugin", + "description": "Authenticated proxy to Databricks Model Serving endpoints", + "resources": { + "required": [ + { + "type": "serving_endpoint", + "alias": "Serving Endpoint", + "resourceKey": "serving-endpoint", + "description": "Model Serving endpoint for inference", + "permission": "CAN_QUERY", + "fields": { + "name": { + "env": "DATABRICKS_SERVING_ENDPOINT", + "description": "Serving endpoint name" + } + } + } + ], + "optional": [] + }, + "config": { + "schema": { + "type": "object", + "properties": { + "endpoints": { + "type": "object", + "description": "Map of alias names to endpoint configurations", + "additionalProperties": { + "type": "object", + "properties": { + "env": { + "type": "string", + "description": "Environment variable holding the endpoint name" + }, + "servedModel": { + "type": "string", + "description": "Target a specific served model (bypasses traffic routing)" + } + }, + "required": ["env"] + } + }, + "timeout": { + "type": "number", + "default": 120000, + "description": "Request timeout in ms. Default: 120000 (2 min)" + } + } + } + } +} diff --git a/packages/appkit/src/plugins/serving/schema-filter.ts b/packages/appkit/src/plugins/serving/schema-filter.ts new file mode 100644 index 00000000..6e52294a --- /dev/null +++ b/packages/appkit/src/plugins/serving/schema-filter.ts @@ -0,0 +1,127 @@ +import fs from "node:fs/promises"; +import { createLogger } from "../../logging/logger"; + +const CACHE_VERSION = "1"; + +interface ServingCacheEntry { + hash: string; + requestType: string; + responseType: string; + chunkType: string | null; +} + +interface ServingCache { + version: string; + endpoints: Record; +} + +const logger = createLogger("serving:schema-filter"); + +function isValidCache(data: unknown): data is ServingCache { + return ( + typeof data === "object" && + data !== null && + "version" in data && + (data as ServingCache).version === CACHE_VERSION && + "endpoints" in data && + typeof (data as ServingCache).endpoints === "object" + ); +} + +/** + * Loads endpoint schemas from the type generation cache file. + * Returns a map of alias → allowed parameter keys. + */ +export async function loadEndpointSchemas( + cacheFile: string, +): Promise>> { + const allowlists = new Map>(); + + try { + const raw = await fs.readFile(cacheFile, "utf8"); + const parsed: unknown = JSON.parse(raw); + if (!isValidCache(parsed)) { + logger.warn("Serving types cache has invalid structure, skipping"); + return allowlists; + } + const cache = parsed; + + for (const [alias, entry] of Object.entries(cache.endpoints)) { + // Extract property keys from the requestType string + // The requestType is a TypeScript object type like "{ messages: ...; temperature: ...; }" + const keys = extractPropertyKeys(entry.requestType); + if (keys.size > 0) { + allowlists.set(alias, keys); + } + } + } catch (err) { + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + logger.warn( + "Failed to load serving types cache: %s", + (err as Error).message, + ); + } + // No cache → no filtering, passthrough mode + } + + return allowlists; +} + +/** + * Extracts top-level property keys from a TypeScript object type string. + * Matches patterns like `key:` or `key?:` at the first nesting level. + */ +function extractPropertyKeys(typeStr: string): Set { + const keys = new Set(); + // Match property names at the top level of the object type + // Looking for patterns: ` propertyName:` or ` propertyName?:` + const propRegex = /^\s{2}(?:\/\*\*[^*]*\*\/\s*)?(\w+)\??:/gm; + for ( + let match = propRegex.exec(typeStr); + match !== null; + match = propRegex.exec(typeStr) + ) { + keys.add(match[1]); + } + return keys; +} + +/** + * Filters a request body against the allowed keys for an endpoint alias. + * Returns the filtered body and logs a warning for stripped params. + * + * If no allowlist exists for the alias, returns the body unchanged (passthrough). + */ +export function filterRequestBody( + body: Record, + allowlists: Map>, + alias: string, + filterMode: "strip" | "reject" = "strip", +): Record { + const allowed = allowlists.get(alias); + if (!allowed) return body; + + const stripped: string[] = []; + const filtered: Record = {}; + + for (const [key, value] of Object.entries(body)) { + if (allowed.has(key)) { + filtered[key] = value; + } else { + stripped.push(key); + } + } + + if (stripped.length > 0) { + if (filterMode === "reject") { + throw new Error(`Unknown request parameters: ${stripped.join(", ")}`); + } + logger.warn( + "Stripped unknown params from '%s': %s", + alias, + stripped.join(", "), + ); + } + + return filtered; +} diff --git a/packages/appkit/src/plugins/serving/serving.ts b/packages/appkit/src/plugins/serving/serving.ts new file mode 100644 index 00000000..e3547bcf --- /dev/null +++ b/packages/appkit/src/plugins/serving/serving.ts @@ -0,0 +1,304 @@ +import { randomUUID } from "node:crypto"; +import path from "node:path"; +import type express from "express"; +import type { IAppRouter, StreamExecutionSettings } from "shared"; +import * as servingConnector from "../../connectors/serving/client"; +import { getWorkspaceClient } from "../../context"; +import { createLogger } from "../../logging"; +import { Plugin, toPlugin } from "../../plugin"; +import type { PluginManifest, ResourceRequirement } from "../../registry"; +import { ResourceType } from "../../registry"; +import { servingInvokeDefaults, servingStreamDefaults } from "./defaults"; +import manifest from "./manifest.json"; +import { filterRequestBody, loadEndpointSchemas } from "./schema-filter"; +import type { EndpointConfig, IServingConfig, ServingFactory } from "./types"; + +const logger = createLogger("serving"); + +class EndpointNotFoundError extends Error { + constructor(alias: string) { + super(`Unknown endpoint alias: ${alias}`); + } +} + +class EndpointNotConfiguredError extends Error { + constructor(alias: string, envVar: string) { + super( + `Endpoint '${alias}' is not configured: env var '${envVar}' is not set`, + ); + } +} + +interface ResolvedEndpoint { + name: string; + servedModel?: string; +} + +export class ServingPlugin extends Plugin { + static manifest = manifest as PluginManifest<"serving">; + + protected static description = + "Authenticated proxy to Databricks Model Serving endpoints"; + protected declare config: IServingConfig; + + private readonly endpoints: Record; + private readonly isNamedMode: boolean; + private schemaAllowlists = new Map>(); + + constructor(config: IServingConfig) { + super(config); + this.config = config; + + if (config.endpoints) { + this.endpoints = config.endpoints; + this.isNamedMode = true; + } else { + this.endpoints = { + default: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }; + this.isNamedMode = false; + } + } + + async setup(): Promise { + const cacheFile = path.join( + process.cwd(), + "node_modules", + ".databricks", + "appkit", + ".appkit-serving-types-cache.json", + ); + this.schemaAllowlists = await loadEndpointSchemas(cacheFile); + if (this.schemaAllowlists.size > 0) { + logger.debug( + "Loaded schema allowlists for %d endpoint(s)", + this.schemaAllowlists.size, + ); + } + } + + static getResourceRequirements( + config: IServingConfig, + ): ResourceRequirement[] { + const endpoints = config.endpoints ?? { + default: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }; + + return Object.entries(endpoints).map(([alias, endpointConfig]) => ({ + type: ResourceType.SERVING_ENDPOINT, + alias: `serving-${alias}`, + resourceKey: `serving-${alias}`, + description: `Model Serving endpoint for "${alias}" inference`, + permission: "CAN_QUERY" as const, + fields: { + name: { + env: endpointConfig.env, + description: `Serving endpoint name for "${alias}"`, + }, + }, + required: true, + })); + } + + private resolveAndFilter( + alias: string, + body: Record, + ): { endpoint: ResolvedEndpoint; filteredBody: Record } { + const config = this.endpoints[alias]; + if (!config) { + throw new EndpointNotFoundError(alias); + } + + const name = process.env[config.env]; + if (!name) { + throw new EndpointNotConfiguredError(alias, config.env); + } + + const endpoint: ResolvedEndpoint = { + name, + servedModel: config.servedModel, + }; + const filteredBody = filterRequestBody( + body, + this.schemaAllowlists, + alias, + this.config.filterMode, + ); + return { endpoint, filteredBody }; + } + + injectRoutes(router: IAppRouter) { + if (this.isNamedMode) { + this.route(router, { + name: "invoke", + method: "post", + path: "/:alias/invoke", + handler: async (req: express.Request, res: express.Response) => { + await this.asUser(req)._handleInvoke(req, res); + }, + }); + + this.route(router, { + name: "stream", + method: "post", + path: "/:alias/stream", + handler: async (req: express.Request, res: express.Response) => { + await this.asUser(req)._handleStream(req, res); + }, + }); + } else { + this.route(router, { + name: "invoke", + method: "post", + path: "/invoke", + handler: async (req: express.Request, res: express.Response) => { + req.params.alias = "default"; + await this.asUser(req)._handleInvoke(req, res); + }, + }); + + this.route(router, { + name: "stream", + method: "post", + path: "/stream", + handler: async (req: express.Request, res: express.Response) => { + req.params.alias = "default"; + await this.asUser(req)._handleStream(req, res); + }, + }); + } + } + + async _handleInvoke( + req: express.Request, + res: express.Response, + ): Promise { + const { alias } = req.params; + const rawBody = req.body as Record; + + try { + const result = await this.invoke(alias, rawBody); + if (result === undefined) { + res.status(502).json({ error: "Invocation returned no result" }); + return; + } + res.json(result); + } catch (err) { + const message = err instanceof Error ? err.message : "Invocation failed"; + if (err instanceof EndpointNotFoundError) { + res.status(404).json({ error: message }); + } else if ( + err instanceof EndpointNotConfiguredError || + message.startsWith("Unknown request parameters:") + ) { + res.status(400).json({ error: message }); + } else { + res.status(502).json({ error: message }); + } + } + } + + async _handleStream( + req: express.Request, + res: express.Response, + ): Promise { + const { alias } = req.params; + const rawBody = req.body as Record; + + let endpoint: ResolvedEndpoint; + let filteredBody: Record; + try { + ({ endpoint, filteredBody } = this.resolveAndFilter(alias, rawBody)); + } catch (err) { + const message = err instanceof Error ? err.message : "Invalid request"; + const status = err instanceof EndpointNotFoundError ? 404 : 400; + res.status(status).json({ error: message }); + return; + } + + const timeout = this.config.timeout ?? 120_000; + const requestId = + (typeof req.query.requestId === "string" && req.query.requestId) || + randomUUID(); + + const streamSettings: StreamExecutionSettings = { + ...servingStreamDefaults, + default: { + ...servingStreamDefaults.default, + timeout, + }, + stream: { + ...servingStreamDefaults.stream, + streamId: requestId, + }, + }; + + const workspaceClient = getWorkspaceClient(); + if (!workspaceClient.config.host) { + res.status(500).json({ error: "Databricks host not configured" }); + return; + } + + await this.executeStream( + res, + (signal) => + servingConnector.stream(workspaceClient, endpoint.name, filteredBody, { + servedModel: endpoint.servedModel, + signal, + }), + streamSettings, + ); + } + + async invoke(alias: string, body: Record): Promise { + const { endpoint, filteredBody } = this.resolveAndFilter(alias, body); + const workspaceClient = getWorkspaceClient(); + const timeout = this.config.timeout ?? 120_000; + + return this.execute( + () => + servingConnector.invoke(workspaceClient, endpoint.name, filteredBody, { + servedModel: endpoint.servedModel, + }), + { + default: { + ...servingInvokeDefaults, + timeout, + }, + }, + ); + } + + async *stream( + alias: string, + body: Record, + ): AsyncGenerator { + const { endpoint, filteredBody } = this.resolveAndFilter(alias, body); + const workspaceClient = getWorkspaceClient(); + + yield* servingConnector.stream( + workspaceClient, + endpoint.name, + filteredBody, + { servedModel: endpoint.servedModel }, + ); + } + + async shutdown(): Promise { + this.streamManager.abortAll(); + } + + exports(): ServingFactory { + return ((alias?: string) => ({ + invoke: (body: Record) => + this.invoke(alias ?? "default", body), + stream: (body: Record) => + this.stream(alias ?? "default", body), + })) as ServingFactory; + } +} + +/** + * @internal + */ +export const serving = toPlugin(ServingPlugin); diff --git a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts new file mode 100644 index 00000000..948b47f9 --- /dev/null +++ b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts @@ -0,0 +1,141 @@ +import { describe, expect, test, vi } from "vitest"; +import { filterRequestBody, loadEndpointSchemas } from "../schema-filter"; + +vi.mock("node:fs/promises", () => ({ + default: { + readFile: vi.fn(), + }, +})); + +describe("schema-filter", () => { + describe("filterRequestBody", () => { + test("strips unknown keys when allowlist exists", () => { + const allowlists = new Map([ + ["default", new Set(["messages", "temperature"])], + ]); + + const result = filterRequestBody( + { messages: [], temperature: 0.7, unknown_param: true }, + allowlists, + "default", + ); + + expect(result).toEqual({ messages: [], temperature: 0.7 }); + }); + + test("preserves all keys when no allowlist for alias", () => { + const allowlists = new Map>(); + + const body = { messages: [], custom: "value" }; + const result = filterRequestBody(body, allowlists, "default"); + + expect(result).toBe(body); // Same reference, no filtering + }); + + test("returns empty object when all keys are unknown", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + const result = filterRequestBody( + { bad1: 1, bad2: 2 }, + allowlists, + "default", + ); + + expect(result).toEqual({}); + }); + + test("returns full body when all keys are allowed", () => { + const allowlists = new Map([["default", new Set(["a", "b", "c"])]]); + + const result = filterRequestBody( + { a: 1, b: 2, c: 3 }, + allowlists, + "default", + ); + + expect(result).toEqual({ a: 1, b: 2, c: 3 }); + }); + + test("throws in reject mode when unknown keys are present", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + expect(() => + filterRequestBody( + { messages: [], unknown_param: true }, + allowlists, + "default", + "reject", + ), + ).toThrow("Unknown request parameters: unknown_param"); + }); + + test("does not throw in reject mode when all keys are allowed", () => { + const allowlists = new Map([ + ["default", new Set(["messages", "temperature"])], + ]); + + const result = filterRequestBody( + { messages: [], temperature: 0.7 }, + allowlists, + "default", + "reject", + ); + + expect(result).toEqual({ messages: [], temperature: 0.7 }); + }); + + test("strips in default mode (strip)", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + const result = filterRequestBody( + { messages: [], extra: true }, + allowlists, + "default", + "strip", + ); + + expect(result).toEqual({ messages: [] }); + }); + }); + + describe("loadEndpointSchemas", () => { + test("returns empty map when cache file does not exist", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockRejectedValue( + Object.assign(new Error("ENOENT"), { code: "ENOENT" }), + ); + + const result = await loadEndpointSchemas("/nonexistent/path"); + expect(result.size).toBe(0); + }); + + test("extracts property keys from cached types", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + version: "1", + endpoints: { + default: { + hash: "abc", + requestType: `{ + messages: string[]; + temperature?: number | null; + max_tokens: number; +}`, + responseType: "{}", + chunkType: null, + }, + }, + }), + ); + + const result = await loadEndpointSchemas("/some/path"); + expect(result.size).toBe(1); + const keys = result.get("default"); + expect(keys).toBeDefined(); + expect(keys?.has("messages")).toBe(true); + expect(keys?.has("temperature")).toBe(true); + expect(keys?.has("max_tokens")).toBe(true); + }); + }); +}); diff --git a/packages/appkit/src/plugins/serving/tests/serving.test.ts b/packages/appkit/src/plugins/serving/tests/serving.test.ts new file mode 100644 index 00000000..1a953b77 --- /dev/null +++ b/packages/appkit/src/plugins/serving/tests/serving.test.ts @@ -0,0 +1,339 @@ +import { + createMockRequest, + createMockResponse, + createMockRouter, + mockServiceContext, + setupDatabricksEnv, +} from "@tools/test-helpers"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { ServiceContext } from "../../../context/service-context"; +import { ServingPlugin, serving } from "../serving"; +import type { IServingConfig } from "../types"; + +// Mock CacheManager singleton +const { mockCacheInstance } = vi.hoisted(() => { + const instance = { + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi + .fn() + .mockImplementation( + async (_key: unknown[], fn: () => Promise) => { + return await fn(); + }, + ), + generateKey: vi.fn((...args: unknown[]) => JSON.stringify(args)), + }; + return { mockCacheInstance: instance }; +}); + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => mockCacheInstance), + }, +})); + +// Mock the serving connector +const mockInvoke = vi.fn(); +const mockStream = vi.fn(); + +vi.mock("../../../connectors/serving/client", () => ({ + invoke: (...args: any[]) => mockInvoke(...args), + stream: (...args: any[]) => mockStream(...args), +})); + +describe("Serving Plugin", () => { + let serviceContextMock: Awaited>; + + beforeEach(async () => { + setupDatabricksEnv(); + process.env.DATABRICKS_SERVING_ENDPOINT = "test-endpoint"; + ServiceContext.reset(); + + serviceContextMock = await mockServiceContext(); + }); + + afterEach(() => { + serviceContextMock?.restore(); + delete process.env.DATABRICKS_SERVING_ENDPOINT; + vi.restoreAllMocks(); + }); + + test("serving factory should have correct name", () => { + const pluginData = serving(); + expect(pluginData.name).toBe("serving"); + }); + + test("serving factory with config should have correct name", () => { + const pluginData = serving({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + expect(pluginData.name).toBe("serving"); + }); + + describe("default mode", () => { + test("reads DATABRICKS_SERVING_ENDPOINT", () => { + const plugin = new ServingPlugin({}); + const api = (plugin.exports() as any)(); + expect(api.invoke).toBeDefined(); + expect(api.stream).toBeDefined(); + }); + + test("injects /invoke and /stream routes", () => { + const plugin = new ServingPlugin({}); + const { router, handlers } = createMockRouter(); + + plugin.injectRoutes(router); + + expect(handlers["POST:/invoke"]).toBeDefined(); + expect(handlers["POST:/stream"]).toBeDefined(); + }); + + test("exports returns a factory that provides invoke and stream", () => { + const plugin = new ServingPlugin({}); + const factory = plugin.exports() as any; + const api = factory(); + + expect(typeof api.invoke).toBe("function"); + expect(typeof api.stream).toBe("function"); + }); + }); + + describe("named mode", () => { + const namedConfig: IServingConfig = { + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT" }, + embedder: { env: "DATABRICKS_SERVING_ENDPOINT_EMBEDDING" }, + }, + }; + + test("injects /:alias/invoke and /:alias/stream routes", () => { + const plugin = new ServingPlugin(namedConfig); + const { router, handlers } = createMockRouter(); + + plugin.injectRoutes(router); + + expect(handlers["POST:/:alias/invoke"]).toBeDefined(); + expect(handlers["POST:/:alias/stream"]).toBeDefined(); + }); + + test("exports factory returns invoke and stream for named aliases", () => { + const plugin = new ServingPlugin(namedConfig); + const factory = plugin.exports() as any; + + expect(typeof factory("llm").invoke).toBe("function"); + expect(typeof factory("llm").stream).toBe("function"); + expect(typeof factory("embedder").invoke).toBe("function"); + expect(typeof factory("embedder").stream).toBe("function"); + }); + }); + + describe("route handlers", () => { + test("_handleInvoke returns 404 for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + const req = createMockRequest({ + params: { alias: "unknown" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: "Unknown endpoint alias: unknown", + }); + }); + + test("_handleInvoke calls connector with correct endpoint", async () => { + mockInvoke.mockResolvedValue({ choices: [] }); + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [{ role: "user", content: "Hello" }] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(mockInvoke).toHaveBeenCalledWith( + expect.anything(), + "test-endpoint", + { messages: [{ role: "user", content: "Hello" }] }, + { servedModel: undefined }, + ); + expect(res.json).toHaveBeenCalledWith({ choices: [] }); + }); + + test("_handleInvoke returns 400 with descriptive message when env var is not set", async () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT; + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith({ + error: + "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT' is not set", + }); + }); + + test("_handleInvoke does not throw when connector fails", async () => { + mockInvoke.mockRejectedValue(new Error("Connection refused")); + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + // Should not throw — execute() handles the error internally + await expect( + plugin._handleInvoke(req as any, res as any), + ).resolves.not.toThrow(); + }); + + test("_handleStream returns 404 for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + const req = createMockRequest({ + params: { alias: "unknown" }, + body: { messages: [] }, + query: {}, + }); + const res = createMockResponse(); + + await plugin._handleStream(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: "Unknown endpoint alias: unknown", + }); + }); + + test("_handleStream returns 400 when env var is not set", async () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT; + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + query: {}, + }); + const res = createMockResponse(); + + await plugin._handleStream(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith({ + error: + "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT' is not set", + }); + }); + }); + + describe("getResourceRequirements", () => { + test("generates requirements for default mode", () => { + const reqs = ServingPlugin.getResourceRequirements({}); + expect(reqs).toHaveLength(1); + expect(reqs[0]).toMatchObject({ + type: "serving_endpoint", + alias: "serving-default", + permission: "CAN_QUERY", + fields: { + name: { + env: "DATABRICKS_SERVING_ENDPOINT", + }, + }, + }); + }); + + test("generates requirements for named mode", () => { + const reqs = ServingPlugin.getResourceRequirements({ + endpoints: { + llm: { env: "LLM_ENDPOINT" }, + embedder: { env: "EMBED_ENDPOINT" }, + }, + }); + expect(reqs).toHaveLength(2); + expect(reqs[0].fields.name.env).toBe("LLM_ENDPOINT"); + expect(reqs[1].fields.name.env).toBe("EMBED_ENDPOINT"); + }); + }); + + describe("programmatic API", () => { + test("invoke calls connector correctly", async () => { + mockInvoke.mockResolvedValue({ + choices: [{ message: { content: "Hi" } }], + }); + + const plugin = new ServingPlugin({}); + const result = await plugin.invoke("default", { messages: [] }); + + expect(mockInvoke).toHaveBeenCalledWith( + expect.anything(), + "test-endpoint", + { messages: [] }, + { servedModel: undefined }, + ); + expect(result).toEqual({ choices: [{ message: { content: "Hi" } }] }); + }); + + test("invoke throws for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + await expect(plugin.invoke("unknown", { messages: [] })).rejects.toThrow( + "Unknown endpoint alias: unknown", + ); + }); + + test("stream yields chunks from connector", async () => { + const chunks = [ + { choices: [{ delta: { content: "Hello" } }] }, + { choices: [{ delta: { content: " world" } }] }, + ]; + + mockStream.mockImplementation(async function* () { + for (const chunk of chunks) { + yield chunk; + } + }); + + const plugin = new ServingPlugin({}); + const results: unknown[] = []; + for await (const chunk of plugin.stream("default", { messages: [] })) { + results.push(chunk); + } + + expect(results).toEqual(chunks); + }); + }); + + describe("shutdown", () => { + test("calls streamManager.abortAll", async () => { + const plugin = new ServingPlugin({}); + // Accessing the protected streamManager through the plugin + const abortSpy = vi.spyOn((plugin as any).streamManager, "abortAll"); + + await plugin.shutdown(); + + expect(abortSpy).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/appkit/src/plugins/serving/types.ts b/packages/appkit/src/plugins/serving/types.ts new file mode 100644 index 00000000..9a2dd230 --- /dev/null +++ b/packages/appkit/src/plugins/serving/types.ts @@ -0,0 +1,67 @@ +import type { BasePluginConfig } from "shared"; + +export interface EndpointConfig { + /** Environment variable holding the endpoint name. */ + env: string; + /** Target a specific served model (bypasses traffic routing). */ + servedModel?: string; +} + +export interface IServingConfig extends BasePluginConfig { + /** Map of alias → endpoint config. Defaults to { default: { env: "DATABRICKS_SERVING_ENDPOINT" } } if omitted. */ + endpoints?: Record; + /** Request timeout in ms. Default: 120000 (2 min) */ + timeout?: number; + /** How to handle unknown request parameters. 'strip' silently removes them (default). 'reject' returns 400. */ + filterMode?: "strip" | "reject"; +} + +/** + * Registry interface for serving endpoint type generation. + * Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. + * When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. + */ +// biome-ignore lint/suspicious/noEmptyInterface: intentionally empty — populated via module augmentation +export interface ServingEndpointRegistry {} + +/** Shape of a single registry entry. */ +export interface ServingEndpointEntry { + request: Record; + response: unknown; + chunk: unknown; +} + +/** Typed invoke/stream methods for a serving endpoint. */ +export interface ServingEndpointMethods< + TRequest extends Record = Record, + TResponse = unknown, + TChunk = unknown, +> { + invoke: (body: TRequest) => Promise; + stream: (body: TRequest) => AsyncGenerator; +} + +/** + * Factory function returned by `AppKit.serving`. + * + * This is a conditional type that adapts based on whether `ServingEndpointRegistry` + * has been populated via module augmentation (generated by `appKitServingTypesPlugin()`): + * + * - **Registry empty (default):** `(alias?: string) => ServingEndpointMethods` — + * accepts any alias string with untyped request/response/chunk. + * - **Registry populated:** `(alias: K) => ServingEndpointMethods<...>` — + * restricts `alias` to known endpoint keys and infers typed request/response/chunk + * from the registry entry. + * + * Run `appKitServingTypesPlugin()` in your Vite config to generate the registry + * augmentation and enable full type safety. + */ +export type ServingFactory = keyof ServingEndpointRegistry extends never + ? (alias?: string) => ServingEndpointMethods + : ( + alias: K, + ) => ServingEndpointMethods< + ServingEndpointRegistry[K]["request"], + ServingEndpointRegistry[K]["response"], + ServingEndpointRegistry[K]["chunk"] + >; diff --git a/packages/appkit/src/stream/stream-manager.ts b/packages/appkit/src/stream/stream-manager.ts index 41764772..8b511fac 100644 --- a/packages/appkit/src/stream/stream-manager.ts +++ b/packages/appkit/src/stream/stream-manager.ts @@ -374,6 +374,14 @@ export class StreamManager { if (error.name === "AbortError") { return SSEErrorCode.STREAM_ABORTED; } + + // Detect upstream API errors (e.g., from Databricks SDK ApiError) + if ( + "statusCode" in error && + typeof (error as any).statusCode === "number" + ) { + return SSEErrorCode.UPSTREAM_ERROR; + } } return SSEErrorCode.INTERNAL_ERROR; diff --git a/packages/appkit/src/stream/types.ts b/packages/appkit/src/stream/types.ts index 0fd862ba..3841bfd1 100644 --- a/packages/appkit/src/stream/types.ts +++ b/packages/appkit/src/stream/types.ts @@ -16,6 +16,7 @@ export const SSEErrorCode = { INVALID_REQUEST: "INVALID_REQUEST", STREAM_ABORTED: "STREAM_ABORTED", STREAM_EVICTED: "STREAM_EVICTED", + UPSTREAM_ERROR: "UPSTREAM_ERROR", } as const satisfies Record; export type SSEErrorCode = (typeof SSEErrorCode)[keyof typeof SSEErrorCode];