Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/app/api/fim/completions/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ export async function POST(request: NextRequest) {
const promptInfo = extractFimPromptInfo(requestBody);

const userByok = organizationId
? await getBYOKforOrganization(readDb, organizationId, 'codestral')
: await getBYOKforUser(readDb, user.id, 'codestral');
? await getBYOKforOrganization(readDb, organizationId, ['codestral'])
: await getBYOKforUser(readDb, user.id, ['codestral']);

const usageContext: MicrodollarUsageContext = {
kiloUserId: user.id,
Expand Down Expand Up @@ -188,7 +188,7 @@ export async function POST(request: NextRequest) {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${userByok?.decryptedAPIKey ?? MISTRAL_API_KEY}`,
Authorization: `Bearer ${userByok?.at(0)?.decryptedAPIKey ?? MISTRAL_API_KEY}`,
},
body: JSON.stringify(bodyWithCorrectedModel),
});
Expand Down
103 changes: 67 additions & 36 deletions src/lib/byok/index.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,64 @@
import type { db } from '@/lib/drizzle';
import { byok_api_keys } from '@/db/schema';
import { eq, and, sql } from 'drizzle-orm';
import { readDb, type db } from '@/lib/drizzle';
import { byok_api_keys, modelsByProvider } from '@/db/schema';
import { eq, and, inArray } from 'drizzle-orm';
import { desc } from 'drizzle-orm';
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SUGGESTION: Duplicate drizzle-orm import — desc can be combined with the import on line 3.

Suggested change
import { desc } from 'drizzle-orm';
import { eq, and, inArray, desc } from 'drizzle-orm';

import { decryptApiKey } from '@/lib/byok/encryption';
import { BYOK_ENCRYPTION_KEY } from '@/lib/config.server';
import type { UserByokProviderId } from '@/lib/providers/openrouter/inference-provider-id';
import {
UserByokProviderIdSchema,
VercelUserByokInferenceProviderIdSchema,
type UserByokProviderId,
} from '@/lib/providers/openrouter/inference-provider-id';
import { isCodestralModel } from '@/lib/providers/mistral';
import { unstable_cache } from 'next/cache';
import { mapModelIdToVercel } from '@/lib/providers/vercel/mapModelIdToVercel';

/** A stored BYOK API key after decryption, paired with the provider it is registered for. */
export type BYOKResult = {
  /** The API key in usable form, as returned by `decryptApiKey`. */
  decryptedAPIKey: string;
  /** The provider this key belongs to. */
  providerId: UserByokProviderId;
};

/**
 * Looks up which user-BYOK inference providers are listed for a model in the Vercel
 * model metadata stored in the database.
 *
 * Reads the `vercel` column of the `modelsByProvider` row with the highest id, finds the
 * entry keyed by the Vercel form of `modelId`, and keeps every endpoint tag accepted by
 * `VercelUserByokInferenceProviderIdSchema`. The lookup is wrapped in `unstable_cache`
 * with `revalidate: 300`, so results (including empty ones) are reused until revalidation.
 *
 * @param modelId - The requested model ID; passed through `mapModelIdToVercel` before lookup
 * @returns The matching provider IDs, or an empty array when metadata or matches are missing
 */
const getModelUserByokProviders_cached = unstable_cache(
  async (modelId: string) => {
    const latestRows = await readDb
      .select({ vercel: modelsByProvider.vercel })
      .from(modelsByProvider)
      .orderBy(desc(modelsByProvider.id))
      .limit(1);
    const vercelModelMetadata = latestRows.at(0)?.vercel;
    if (!vercelModelMetadata) {
      console.error('[getModelUserByokProviders_cached] no Vercel model metadata in the database');
      return [];
    }

    const modelEntry = vercelModelMetadata[mapModelIdToVercel(modelId)];
    const providers =
      modelEntry?.endpoints.flatMap(endpoint => {
        // Tags the schema rejects yield `undefined` data and are dropped.
        const providerId = VercelUserByokInferenceProviderIdSchema.safeParse(endpoint.tag).data;
        return providerId === undefined ? [] : [providerId];
      }) ?? [];

    if (providers.length > 0) {
      console.debug(
        `[getModelUserByokProviders_cached] found user byok providers for ${modelId}`,
        providers
      );
      return providers;
    }

    console.debug(`[getModelUserByokProviders_cached] no user byok providers for ${modelId}`);
    return [];
  },
  undefined,
  { revalidate: 300 }
);

/**
 * Resolves the BYOK provider IDs to query keys for when serving `model`.
 *
 * Codestral models always resolve to the single `codestral` provider; every other model
 * goes through the cached metadata lookup.
 *
 * @param model - The requested model ID
 * @returns The provider IDs for the model (possibly empty)
 */
export async function getModelUserByokProviders(model: string): Promise<UserByokProviderId[]> {
  if (isCodestralModel(model)) {
    return ['codestral'];
  }
  const providersFromMetadata = await getModelUserByokProviders_cached(model);
  return providersFromMetadata;
}

export async function getBYOKforUser(
fromDb: typeof db,
userId: string,
providerId: UserByokProviderId
): Promise<BYOKResult | null> {
const [row] = await fromDb
providerIds: UserByokProviderId[]
): Promise<BYOKResult[] | null> {
const rows = await fromDb
.select({
encrypted_api_key: byok_api_keys.encrypted_api_key,
provider_id: byok_api_keys.provider_id,
Expand All @@ -32,33 +68,27 @@ export async function getBYOKforUser(
and(
eq(byok_api_keys.kilo_user_id, userId),
eq(byok_api_keys.is_enabled, true),
sql`lower(${byok_api_keys.provider_id}) = lower(${providerId})`
inArray(byok_api_keys.provider_id, providerIds)
)
);
)
.orderBy(byok_api_keys.created_at);

if (!row) {
if (rows.length === 0) {
return null;
}

return {
return rows.map(row => ({
decryptedAPIKey: decryptApiKey(row.encrypted_api_key, BYOK_ENCRYPTION_KEY),
providerId: row.provider_id as UserByokProviderId,
};
providerId: UserByokProviderIdSchema.parse(row.provider_id),
}));
}

/**
* Retrieves a decrypted BYOK API key for an organization and provider.
*
* @param organizationId - The organization ID
* @param providerId - The provider ID (case-insensitive match)
* @returns Object with decrypted API key and provider ID if found, null otherwise
*/
export async function getBYOKforOrganization(
fromDb: typeof db,
organizationId: string,
providerId: UserByokProviderId
): Promise<BYOKResult | null> {
const [row] = await fromDb
providerIds: UserByokProviderId[]
): Promise<BYOKResult[] | null> {
const rows = await fromDb
.select({
encrypted_api_key: byok_api_keys.encrypted_api_key,
provider_id: byok_api_keys.provider_id,
Expand All @@ -68,16 +98,17 @@ export async function getBYOKforOrganization(
and(
eq(byok_api_keys.organization_id, organizationId),
eq(byok_api_keys.is_enabled, true),
sql`lower(${byok_api_keys.provider_id}) = lower(${providerId})`
inArray(byok_api_keys.provider_id, providerIds)
)
);
)
.orderBy(byok_api_keys.created_at);

if (!row) {
if (rows.length === 0) {
return null;
}

return {
return rows.map(row => ({
decryptedAPIKey: decryptApiKey(row.encrypted_api_key, BYOK_ENCRYPTION_KEY),
providerId: row.provider_id as UserByokProviderId,
};
providerId: UserByokProviderIdSchema.parse(row.provider_id),
}));
}
29 changes: 16 additions & 13 deletions src/lib/providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ import {
isHaikuModel,
} from '@/lib/providers/anthropic';
import { applyGigaPotatoProviderSettings } from '@/lib/providers/gigapotato';
import { getBYOKforOrganization, getBYOKforUser, type BYOKResult } from '@/lib/byok';
import {
getBYOKforOrganization,
getBYOKforUser,
getModelUserByokProviders,
type BYOKResult,
} from '@/lib/byok';
import type { CustomLlm } from '@/db/schema';
import { custom_llm, type User } from '@/db/schema';
import type { OpenRouterInferenceProviderId } from '@/lib/providers/openrouter/inference-provider-id';
import {
inferUserByokProviderForModel,
OpenRouterInferenceProviderIdSchema,
} from '@/lib/providers/openrouter/inference-provider-id';
import { OpenRouterInferenceProviderIdSchema } from '@/lib/providers/openrouter/inference-provider-id';
import { applyCoreThinkProviderSettings } from '@/lib/providers/corethink';
import { hasAttemptCompletionTool } from '@/lib/tool-calling';
import { applyGoogleModelSettings, isGeminiModel } from '@/lib/providers/google';
Expand Down Expand Up @@ -92,14 +94,15 @@ export async function getProvider(
user: User | AnonymousUserContext,
organizationId: string | undefined,
taskId: string | undefined
): Promise<{ provider: Provider; userByok: BYOKResult | null; customLlm: CustomLlm | null }> {
): Promise<{ provider: Provider; userByok: BYOKResult[] | null; customLlm: CustomLlm | null }> {
if (!isAnonymousContext(user)) {
const modelProvider = inferUserByokProviderForModel(requestedModel);
const userByok = !modelProvider
? null
: organizationId
? await getBYOKforOrganization(db, organizationId, modelProvider)
: await getBYOKforUser(db, user.id, modelProvider);
const modelProviders = await getModelUserByokProviders(requestedModel);
const userByok =
modelProviders.length === 0
? null
: organizationId
? await getBYOKforOrganization(db, organizationId, modelProviders)
: await getBYOKforUser(db, user.id, modelProviders);
if (userByok) {
return { provider: PROVIDERS.VERCEL_AI_GATEWAY, userByok, customLlm: null };
}
Expand Down Expand Up @@ -225,7 +228,7 @@ export function applyProviderSpecificLogic(
requestedModel: string,
requestToMutate: OpenRouterChatCompletionRequest,
extraHeaders: Record<string, string>,
userByok: BYOKResult | null
userByok: BYOKResult[] | null
) {
const kiloFreeModel = kiloFreeModels.find(m => m.public_id === requestedModel);
if (kiloFreeModel) {
Expand Down
3 changes: 3 additions & 0 deletions src/lib/providers/mistral.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import {
/**
 * Tells whether a model ID is in the `mistralai/` namespace.
 *
 * @param model - The model ID to check
 * @returns `true` when the ID starts with `mistralai/`
 */
export function isMistralModel(model: string) {
  const mistralPrefix = 'mistralai/';
  return model.startsWith(mistralPrefix);
}
/**
 * Tells whether a model ID refers to a Codestral model.
 *
 * @param model - The model ID to check
 * @returns `true` when the ID starts with `mistralai/codestral`
 */
export function isCodestralModel(model: string) {
  const codestralPrefix = 'mistralai/codestral';
  return model.startsWith(codestralPrefix);
}

export function applyMistralModelSettings(requestToMutate: OpenRouterChatCompletionRequest) {
// mistral recommends this
Expand Down
8 changes: 0 additions & 8 deletions src/lib/providers/openrouter/inference-provider-id.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,6 @@ const modelPrefixToVercelInferenceProviderMapping = {
'z-ai': VercelUserByokInferenceProviderIdSchema.enum.zai,
} as Record<string, VercelInferenceProviderId | undefined>;

/**
 * Infers the single BYOK provider for a model from its ID alone.
 *
 * Codestral models map to the `codestral` provider. For any other model, the inferred
 * Vercel first-party provider is used if `VercelUserByokInferenceProviderIdSchema` accepts it.
 *
 * @param model - The requested model ID
 * @returns The provider ID, or `null` when none can be inferred
 */
export function inferUserByokProviderForModel(model: string): UserByokProviderId | null {
  if (model.startsWith('mistralai/codestral')) {
    return AutocompleteUserByokProviderIdSchema.enum.codestral;
  }

  const firstPartyProvider = inferVercelFirstPartyInferenceProviderForModel(model);
  const parsedProvider = VercelUserByokInferenceProviderIdSchema.safeParse(firstPartyProvider);
  return parsedProvider.data ?? null;
}

export function inferVercelFirstPartyInferenceProviderForModel(
model: string
): VercelInferenceProviderId | null {
Expand Down
58 changes: 24 additions & 34 deletions src/lib/providers/vercel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { isAnthropicModel } from '@/lib/providers/anthropic';
import { getGatewayErrorRate } from '@/lib/providers/gateway-error-rate';
import {
AutocompleteUserByokProviderIdSchema,
inferVercelFirstPartyInferenceProviderForModel,
openRouterToVercelInferenceProviderId,
VercelUserByokInferenceProviderIdSchema,
} from '@/lib/providers/openrouter/inference-provider-id';
Expand All @@ -14,6 +13,7 @@ import type {
VercelInferenceProviderConfig,
VercelProviderConfig,
} from '@/lib/providers/openrouter/types';
import { mapModelIdToVercel } from '@/lib/providers/vercel/mapModelIdToVercel';
import * as crypto from 'crypto';

// EMERGENCY SWITCH
Expand Down Expand Up @@ -101,28 +101,13 @@ function convertProviderOptions(
};
}

const vercelModelIdMapping = {
'arcee-ai/trinity-large-preview:free': 'arcee-ai/trinity-large-preview',
'mistralai/codestral-2508': 'mistral/codestral',
'mistralai/devstral-2512': 'mistral/devstral-2',
} as Record<string, string>;

export function applyVercelSettings(
requestedModel: string,
requestToMutate: OpenRouterChatCompletionRequest,
extraHeaders: Record<string, string>,
userByok: BYOKResult | null
userByok: BYOKResult[] | null
) {
const vercelModelId = vercelModelIdMapping[requestedModel];
if (vercelModelId) {
requestToMutate.model = vercelModelId;
} else {
const firstPartyProvider = inferVercelFirstPartyInferenceProviderForModel(requestedModel);
const slashIndex = requestToMutate.model.indexOf('/');
if (firstPartyProvider && slashIndex >= 0) {
requestToMutate.model = firstPartyProvider + requestToMutate.model.slice(slashIndex);
}
}
requestToMutate.model = mapModelIdToVercel(requestedModel);

if (isAnthropicModel(requestedModel)) {
// https://vercel.com/docs/ai-gateway/model-variants#anthropic-claude-sonnet-4:-1m-token-context-beta
Expand All @@ -133,28 +118,33 @@ export function applyVercelSettings(
}

if (userByok) {
const provider =
userByok.providerId === AutocompleteUserByokProviderIdSchema.enum.codestral
? VercelUserByokInferenceProviderIdSchema.enum.mistral
: userByok.providerId;
const list = new Array<VercelInferenceProviderConfig>();
// Z.AI Coding Plan support
if (provider === VercelUserByokInferenceProviderIdSchema.enum.zai) {
list.push({
apiKey: userByok.decryptedAPIKey,
baseURL: 'https://api.z.ai/api/coding/paas/v4',
});
if (userByok.length === 0) {
throw new Error('Invalid state: userByok should be null or not empty');
}
const byokProviders: Record<string, VercelInferenceProviderConfig[]> = {};
for (const provider of userByok) {
const key =
provider.providerId === AutocompleteUserByokProviderIdSchema.enum.codestral
? VercelUserByokInferenceProviderIdSchema.enum.mistral
: provider.providerId;
const list = new Array<VercelInferenceProviderConfig>();
if (key === VercelUserByokInferenceProviderIdSchema.enum.zai) {
// Z.AI Coding Plan support
list.push({
apiKey: provider.decryptedAPIKey,
baseURL: 'https://api.z.ai/api/coding/paas/v4',
});
}
list.push({ apiKey: provider.decryptedAPIKey });
byokProviders[key] = [...(byokProviders[key] ?? []), ...list];
}
list.push({ apiKey: userByok.decryptedAPIKey });

// this is vercel specific BYOK configuration to force vercel gateway to use the BYOK API key
// for the user/org. If the key is invalid the request will fail - it will not fall back to bill our API key.
requestToMutate.providerOptions = {
gateway: {
only: [provider],
byok: {
[provider]: list,
},
only: Object.keys(byokProviders),
byok: byokProviders,
},
};
} else {
Expand Down
27 changes: 27 additions & 0 deletions src/lib/providers/vercel/mapModelIdToVercel.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { kiloFreeModels } from '@/lib/models';
import { inferVercelFirstPartyInferenceProviderForModel } from '@/lib/providers/openrouter/inference-provider-id';

// Explicit model ID overrides, checked by mapModelIdToVercel before any other rule.
// Keys are the incoming model IDs; values are the IDs returned in their place.
const vercelModelIdMapping: Record<string, string | undefined> = {
'arcee-ai/trinity-large-preview:free': 'arcee-ai/trinity-large-preview',
'mistralai/codestral-2508': 'mistral/codestral',
'mistralai/devstral-2512': 'mistral/devstral-2',
};

/**
 * Translates a requested model ID into the ID to use with Vercel.
 *
 * Resolution order:
 * 1. An explicit entry in `vercelModelIdMapping` wins.
 * 2. Otherwise, if the ID is the `public_id` of an enabled free model on the `openrouter`
 *    gateway, that model's `internal_id` is used in its place.
 * 3. If the resulting ID contains a `/` and a first-party provider is inferred for it, the
 *    part before the first `/` is replaced by that provider. Otherwise the ID is returned as is.
 *
 * @param modelId - The model ID as requested by the caller
 * @returns The mapped model ID, or the (internal) ID unchanged when no rule applies
 */
export function mapModelIdToVercel(modelId: string) {
  // Only own keys count as overrides. A bare `mapping[modelId]` would also resolve inherited
  // members such as `constructor` or `toString` for a caller-supplied ID and return a non-string.
  const hardcodedVercelId = Object.prototype.hasOwnProperty.call(vercelModelIdMapping, modelId)
    ? vercelModelIdMapping[modelId]
    : undefined;
  if (hardcodedVercelId) {
    return hardcodedVercelId;
  }

  const internalId =
    kiloFreeModels.find(m => m.public_id === modelId && m.is_enabled && m.gateway === 'openrouter')
      ?.internal_id ?? modelId;

  // Without a provider prefix there is nothing to rewrite.
  const slashIndex = internalId.indexOf('/');
  if (slashIndex < 0) {
    return internalId;
  }

  const firstPartyProvider = inferVercelFirstPartyInferenceProviderForModel(internalId);
  return firstPartyProvider ? firstPartyProvider + internalId.slice(slashIndex) : internalId;
}