diff --git a/src/app/api/fim/completions/route.ts b/src/app/api/fim/completions/route.ts index 8c4bb824f..c468794e8 100644 --- a/src/app/api/fim/completions/route.ts +++ b/src/app/api/fim/completions/route.ts @@ -126,8 +126,8 @@ export async function POST(request: NextRequest) { const promptInfo = extractFimPromptInfo(requestBody); const userByok = organizationId - ? await getBYOKforOrganization(readDb, organizationId, 'codestral') - : await getBYOKforUser(readDb, user.id, 'codestral'); + ? await getBYOKforOrganization(readDb, organizationId, ['codestral']) + : await getBYOKforUser(readDb, user.id, ['codestral']); const usageContext: MicrodollarUsageContext = { kiloUserId: user.id, @@ -188,7 +188,7 @@ export async function POST(request: NextRequest) { method: 'POST', headers: { 'Content-Type': 'application/json', - Authorization: `Bearer ${userByok?.decryptedAPIKey ?? MISTRAL_API_KEY}`, + Authorization: `Bearer ${userByok?.at(0)?.decryptedAPIKey ?? MISTRAL_API_KEY}`, }, body: JSON.stringify(bodyWithCorrectedModel), }); diff --git a/src/lib/byok/index.ts b/src/lib/byok/index.ts index 35e9fb8ca..25cadb35c 100644 --- a/src/lib/byok/index.ts +++ b/src/lib/byok/index.ts @@ -1,28 +1,64 @@ -import type { db } from '@/lib/drizzle'; -import { byok_api_keys } from '@/db/schema'; -import { eq, and, sql } from 'drizzle-orm'; +import { readDb, type db } from '@/lib/drizzle'; +import { byok_api_keys, modelsByProvider } from '@/db/schema'; +import { eq, and, inArray } from 'drizzle-orm'; +import { desc } from 'drizzle-orm'; import { decryptApiKey } from '@/lib/byok/encryption'; import { BYOK_ENCRYPTION_KEY } from '@/lib/config.server'; -import type { UserByokProviderId } from '@/lib/providers/openrouter/inference-provider-id'; +import { + UserByokProviderIdSchema, + VercelUserByokInferenceProviderIdSchema, + type UserByokProviderId, +} from '@/lib/providers/openrouter/inference-provider-id'; +import { isCodestralModel } from '@/lib/providers/mistral'; +import { unstable_cache } from 'next/cache'; +import { mapModelIdToVercel } from '@/lib/providers/vercel/mapModelIdToVercel'; export type BYOKResult = { decryptedAPIKey: string; providerId: UserByokProviderId; }; -/** - * Retrieves a decrypted BYOK API key for a user and provider. - * - * @param userId - The Kilo user ID - * @param providerId - The provider ID (case-insensitive match) - * @returns Object with decrypted API key and provider ID if found, null otherwise - */ +const getModelUserByokProviders_cached = unstable_cache( + async (modelId: string) => { + const vercelModelMetadata = ( + await readDb + .select({ vercel: modelsByProvider.vercel }) + .from(modelsByProvider) + .orderBy(desc(modelsByProvider.id)) + .limit(1) + ).at(0)?.vercel; + if (!vercelModelMetadata) { + console.error('[getModelUserByokProviders_cached] no Vercel model metadata in the database'); + return []; + } + const providers = + vercelModelMetadata[mapModelIdToVercel(modelId)]?.endpoints + .map(ep => VercelUserByokInferenceProviderIdSchema.safeParse(ep.tag).data) + .filter(providerId => providerId !== undefined) ?? []; + if (providers.length === 0) { + console.debug(`[getModelUserByokProviders_cached] no user byok providers for ${modelId}`); + return []; + } + console.debug( + `[getModelUserByokProviders_cached] found user byok providers for ${modelId}`, + providers + ); + return providers; + }, + undefined, + { revalidate: 300 } +); + +export async function getModelUserByokProviders(model: string): Promise { + return isCodestralModel(model) ? ['codestral'] : await getModelUserByokProviders_cached(model); +} + export async function getBYOKforUser( fromDb: typeof db, userId: string, - providerId: UserByokProviderId -): Promise { - const [row] = await fromDb + providerIds: UserByokProviderId[] +): Promise { + const rows = await fromDb .select({ encrypted_api_key: byok_api_keys.encrypted_api_key, provider_id: byok_api_keys.provider_id, @@ -32,33 +68,27 @@ export async function getBYOKforUser( and( eq(byok_api_keys.kilo_user_id, userId), eq(byok_api_keys.is_enabled, true), - sql`lower(${byok_api_keys.provider_id}) = lower(${providerId})` + inArray(byok_api_keys.provider_id, providerIds) ) - ); + ) + .orderBy(byok_api_keys.created_at); - if (!row) { + if (rows.length === 0) { return null; } - return { + return rows.map(row => ({ decryptedAPIKey: decryptApiKey(row.encrypted_api_key, BYOK_ENCRYPTION_KEY), - providerId: row.provider_id as UserByokProviderId, - }; + providerId: UserByokProviderIdSchema.parse(row.provider_id), + })); } -/** - * Retrieves a decrypted BYOK API key for an organization and provider. - * - * @param organizationId - The organization ID - * @param providerId - The provider ID (case-insensitive match) - * @returns Object with decrypted API key and provider ID if found, null otherwise - */ export async function getBYOKforOrganization( fromDb: typeof db, organizationId: string, - providerId: UserByokProviderId -): Promise { - const [row] = await fromDb + providerIds: UserByokProviderId[] +): Promise { + const rows = await fromDb .select({ encrypted_api_key: byok_api_keys.encrypted_api_key, provider_id: byok_api_keys.provider_id, @@ -68,16 +98,17 @@ export async function getBYOKforOrganization( and( eq(byok_api_keys.organization_id, organizationId), eq(byok_api_keys.is_enabled, true), - sql`lower(${byok_api_keys.provider_id}) = lower(${providerId})` + inArray(byok_api_keys.provider_id, providerIds) ) - ); + ) + .orderBy(byok_api_keys.created_at); - if (!row) { + if (rows.length === 0) { return null; } - return { + return rows.map(row => ({ decryptedAPIKey: decryptApiKey(row.encrypted_api_key, BYOK_ENCRYPTION_KEY), - providerId: row.provider_id as UserByokProviderId, - }; + providerId: UserByokProviderIdSchema.parse(row.provider_id), + })); } diff --git a/src/lib/providers/index.ts b/src/lib/providers/index.ts index b52208f98..59ac91461 100644 --- a/src/lib/providers/index.ts +++ b/src/lib/providers/index.ts @@ -20,14 +20,16 @@ import { isHaikuModel, } from '@/lib/providers/anthropic'; import { applyGigaPotatoProviderSettings } from '@/lib/providers/gigapotato'; -import { getBYOKforOrganization, getBYOKforUser, type BYOKResult } from '@/lib/byok'; +import { + getBYOKforOrganization, + getBYOKforUser, + getModelUserByokProviders, + type BYOKResult, +} from '@/lib/byok'; import type { CustomLlm } from '@/db/schema'; import { custom_llm, type User } from '@/db/schema'; import type { OpenRouterInferenceProviderId } from '@/lib/providers/openrouter/inference-provider-id'; -import { - inferUserByokProviderForModel, - OpenRouterInferenceProviderIdSchema, -} from '@/lib/providers/openrouter/inference-provider-id'; +import { OpenRouterInferenceProviderIdSchema } from '@/lib/providers/openrouter/inference-provider-id'; import { applyCoreThinkProviderSettings } from '@/lib/providers/corethink'; import { hasAttemptCompletionTool } from '@/lib/tool-calling'; import { applyGoogleModelSettings, isGeminiModel } from '@/lib/providers/google'; @@ -92,14 +94,15 @@ export async function getProvider( user: User | AnonymousUserContext, organizationId: string | undefined, taskId: string | undefined -): Promise<{ provider: Provider; userByok: BYOKResult | null; customLlm: CustomLlm | null }> { +): Promise<{ provider: Provider; userByok: BYOKResult[] | null; customLlm: CustomLlm | null }> { if (!isAnonymousContext(user)) { - const modelProvider = inferUserByokProviderForModel(requestedModel); - const userByok = !modelProvider - ? null - : organizationId - ? await getBYOKforOrganization(db, organizationId, modelProvider) - : await getBYOKforUser(db, user.id, modelProvider); + const modelProviders = await getModelUserByokProviders(requestedModel); + const userByok = + modelProviders.length === 0 + ? null + : organizationId + ? await getBYOKforOrganization(db, organizationId, modelProviders) + : await getBYOKforUser(db, user.id, modelProviders); if (userByok) { return { provider: PROVIDERS.VERCEL_AI_GATEWAY, userByok, customLlm: null }; } @@ -225,7 +228,7 @@ export function applyProviderSpecificLogic( requestedModel: string, requestToMutate: OpenRouterChatCompletionRequest, extraHeaders: Record, - userByok: BYOKResult | null + userByok: BYOKResult[] | null ) { const kiloFreeModel = kiloFreeModels.find(m => m.public_id === requestedModel); if (kiloFreeModel) { diff --git a/src/lib/providers/mistral.ts b/src/lib/providers/mistral.ts index c5ec630c4..69e00e0e8 100644 --- a/src/lib/providers/mistral.ts +++ b/src/lib/providers/mistral.ts @@ -8,6 +8,9 @@ import { export function isMistralModel(model: string) { return model.startsWith('mistralai/'); } +export function isCodestralModel(model: string) { + return model.startsWith('mistralai/codestral'); +} export function applyMistralModelSettings(requestToMutate: OpenRouterChatCompletionRequest) { // mistral recommends this diff --git a/src/lib/providers/openrouter/inference-provider-id.ts b/src/lib/providers/openrouter/inference-provider-id.ts index e39477936..0ff353950 100644 --- a/src/lib/providers/openrouter/inference-provider-id.ts +++ b/src/lib/providers/openrouter/inference-provider-id.ts @@ -87,14 +87,6 @@ const modelPrefixToVercelInferenceProviderMapping = { 'z-ai': VercelUserByokInferenceProviderIdSchema.enum.zai, } as Record; -export function inferUserByokProviderForModel(model: string): UserByokProviderId | null { - return model.startsWith('mistralai/codestral') - ? AutocompleteUserByokProviderIdSchema.enum.codestral - : (VercelUserByokInferenceProviderIdSchema.safeParse( - inferVercelFirstPartyInferenceProviderForModel(model) - ).data ?? null); -} - export function inferVercelFirstPartyInferenceProviderForModel( model: string ): VercelInferenceProviderId | null { diff --git a/src/lib/providers/vercel/index.ts b/src/lib/providers/vercel/index.ts index 94c19d98a..ab1841f14 100644 --- a/src/lib/providers/vercel/index.ts +++ b/src/lib/providers/vercel/index.ts @@ -4,7 +4,6 @@ import { isAnthropicModel } from '@/lib/providers/anthropic'; import { getGatewayErrorRate } from '@/lib/providers/gateway-error-rate'; import { AutocompleteUserByokProviderIdSchema, - inferVercelFirstPartyInferenceProviderForModel, openRouterToVercelInferenceProviderId, VercelUserByokInferenceProviderIdSchema, } from '@/lib/providers/openrouter/inference-provider-id'; @@ -14,6 +13,7 @@ import type { VercelInferenceProviderConfig, VercelProviderConfig, } from '@/lib/providers/openrouter/types'; +import { mapModelIdToVercel } from '@/lib/providers/vercel/mapModelIdToVercel'; import * as crypto from 'crypto'; // EMERGENCY SWITCH @@ -101,28 +101,13 @@ function convertProviderOptions( }; } -const vercelModelIdMapping = { - 'arcee-ai/trinity-large-preview:free': 'arcee-ai/trinity-large-preview', - 'mistralai/codestral-2508': 'mistral/codestral', - 'mistralai/devstral-2512': 'mistral/devstral-2', -} as Record; - export function applyVercelSettings( requestedModel: string, requestToMutate: OpenRouterChatCompletionRequest, extraHeaders: Record, - userByok: BYOKResult | null + userByok: BYOKResult[] | null ) { - const vercelModelId = vercelModelIdMapping[requestedModel]; - if (vercelModelId) { - requestToMutate.model = vercelModelId; - } else { - const firstPartyProvider = inferVercelFirstPartyInferenceProviderForModel(requestedModel); - const slashIndex = requestToMutate.model.indexOf('/'); - if (firstPartyProvider && slashIndex >= 0) { - requestToMutate.model = firstPartyProvider + requestToMutate.model.slice(slashIndex); - } - } + requestToMutate.model = mapModelIdToVercel(requestedModel); if (isAnthropicModel(requestedModel)) { // https://vercel.com/docs/ai-gateway/model-variants#anthropic-claude-sonnet-4:-1m-token-context-beta @@ -133,28 +118,33 @@ export function applyVercelSettings( } if (userByok) { - const provider = - userByok.providerId === AutocompleteUserByokProviderIdSchema.enum.codestral - ? VercelUserByokInferenceProviderIdSchema.enum.mistral - : userByok.providerId; - const list = new Array(); - // Z.AI Coding Plan support - if (provider === VercelUserByokInferenceProviderIdSchema.enum.zai) { - list.push({ - apiKey: userByok.decryptedAPIKey, - baseURL: 'https://api.z.ai/api/coding/paas/v4', - }); + if (userByok.length === 0) { + throw new Error('Invalid state: userByok should be null or not empty'); + } + const byokProviders: Record = {}; + for (const provider of userByok) { + const key = + provider.providerId === AutocompleteUserByokProviderIdSchema.enum.codestral + ? VercelUserByokInferenceProviderIdSchema.enum.mistral + : provider.providerId; + const list = new Array(); + if (key === VercelUserByokInferenceProviderIdSchema.enum.zai) { + // Z.AI Coding Plan support + list.push({ + apiKey: provider.decryptedAPIKey, + baseURL: 'https://api.z.ai/api/coding/paas/v4', + }); + } + list.push({ apiKey: provider.decryptedAPIKey }); + byokProviders[key] = [...(byokProviders[key] ?? []), ...list]; } - list.push({ apiKey: userByok.decryptedAPIKey }); // this is vercel specific BYOK configuration to force vercel gateway to use the BYOK API key // for the user/org. If the key is invalid the request will faill - it will not fall back to bill our API key. requestToMutate.providerOptions = { gateway: { - only: [provider], - byok: { - [provider]: list, - }, + only: Object.keys(byokProviders), + byok: byokProviders, }, }; } else { diff --git a/src/lib/providers/vercel/mapModelIdToVercel.ts b/src/lib/providers/vercel/mapModelIdToVercel.ts new file mode 100644 index 000000000..6d109262d --- /dev/null +++ b/src/lib/providers/vercel/mapModelIdToVercel.ts @@ -0,0 +1,27 @@ +import { kiloFreeModels } from '@/lib/models'; +import { inferVercelFirstPartyInferenceProviderForModel } from '@/lib/providers/openrouter/inference-provider-id'; + +const vercelModelIdMapping: Record = { + 'arcee-ai/trinity-large-preview:free': 'arcee-ai/trinity-large-preview', + 'mistralai/codestral-2508': 'mistral/codestral', + 'mistralai/devstral-2512': 'mistral/devstral-2', +}; + +export function mapModelIdToVercel(modelId: string) { + const hardcodedVercelId = vercelModelIdMapping[modelId]; + if (hardcodedVercelId) { + return hardcodedVercelId; + } + + const internalId = + kiloFreeModels.find(m => m.public_id === modelId && m.is_enabled && m.gateway === 'openrouter') + ?.internal_id ?? modelId; + + const slashIndex = internalId.indexOf('/'); + if (slashIndex < 0) { + return internalId; + } + + const firstPartyProvider = inferVercelFirstPartyInferenceProviderForModel(internalId); + return firstPartyProvider ? firstPartyProvider + internalId.slice(slashIndex) : internalId; +}