diff --git a/src/api/providers/__tests__/chutes.spec.ts b/src/api/providers/__tests__/chutes.spec.ts index c89ccb79907..22da3500034 100644 --- a/src/api/providers/__tests__/chutes.spec.ts +++ b/src/api/providers/__tests__/chutes.spec.ts @@ -1,336 +1,490 @@ // npx vitest run api/providers/__tests__/chutes.spec.ts -import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI from "openai" +const { mockStreamText, mockGenerateText, mockGetModels, mockGetModelsFromCache } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), + mockGetModels: vi.fn(), + mockGetModelsFromCache: vi.fn(), +})) -import { chutesDefaultModelId, chutesDefaultModelInfo, DEEP_SEEK_DEFAULT_TEMPERATURE } from "@roo-code/types" +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, + } +}) -import { ChutesHandler } from "../chutes" +vi.mock("@ai-sdk/openai-compatible", () => ({ + createOpenAICompatible: vi.fn(() => { + return vi.fn((modelId: string) => ({ + modelId, + provider: "chutes", + })) + }), +})) -// Create mock functions -const mockCreate = vi.fn() -const mockFetchModel = vi.fn() - -// Mock OpenAI module -vi.mock("openai", () => ({ - default: vi.fn(() => ({ - chat: { - completions: { - create: mockCreate, - }, - }, - })), +vi.mock("../fetchers/modelCache", () => ({ + getModels: mockGetModels, + getModelsFromCache: mockGetModelsFromCache, })) +import type { Anthropic } from "@anthropic-ai/sdk" + +import { chutesDefaultModelId, chutesDefaultModelInfo, DEEP_SEEK_DEFAULT_TEMPERATURE } from "@roo-code/types" + +import { ChutesHandler } from "../chutes" + describe("ChutesHandler", () => { let handler: ChutesHandler beforeEach(() => { vi.clearAllMocks() - // Set up default mock implementation - mockCreate.mockImplementation(async () => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { content: "Test response" }, - index: 0, - }, - ], - usage: null, - } - yield { - choices: [ - { - delta: {}, - index: 0, - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - } - }, - })) - handler = new ChutesHandler({ chutesApiKey: "test-key" }) - // Mock fetchModel to return default model - mockFetchModel.mockResolvedValue({ - id: chutesDefaultModelId, - info: chutesDefaultModelInfo, + mockGetModels.mockResolvedValue({ + [chutesDefaultModelId]: chutesDefaultModelInfo, }) - handler.fetchModel = mockFetchModel + mockGetModelsFromCache.mockReturnValue(undefined) + handler = new ChutesHandler({ chutesApiKey: "test-key" }) }) afterEach(() => { vi.restoreAllMocks() }) - it("should use the correct Chutes base URL", () => { - new ChutesHandler({ chutesApiKey: "test-chutes-api-key" }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://llm.chutes.ai/v1" })) - }) + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(ChutesHandler) + }) - it("should use the provided API key", () => { - const chutesApiKey = "test-chutes-api-key" - new ChutesHandler({ chutesApiKey }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: chutesApiKey })) + it("should use default model when no model ID is provided", () => { + const model = handler.getModel() + expect(model.id).toBe(chutesDefaultModelId) + }) }) - it("should handle DeepSeek R1 reasoning format", async () => { - // Override the mock for this specific test - mockCreate.mockImplementationOnce(async () => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { content: "Thinking..." }, - index: 0, - }, - ], - usage: null, - } - yield { - choices: [ - { - delta: { content: "Hello" }, - index: 0, - }, - ], - usage: null, - } - yield { - choices: [ - { - delta: {}, - index: 0, - }, - ], - usage: { prompt_tokens: 10, completion_tokens: 5 }, - } - }, - })) + describe("getModel", () => { + it("should return default model when no model is specified and no cache", () => { + const model = handler.getModel() + expect(model.id).toBe(chutesDefaultModelId) + expect(model.info).toEqual( + expect.objectContaining({ + ...chutesDefaultModelInfo, + }), + ) + }) - const systemPrompt = "You are a helpful assistant." - const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }] - mockFetchModel.mockResolvedValueOnce({ - id: "deepseek-ai/DeepSeek-R1-0528", - info: { maxTokens: 1024, temperature: 0.7 }, + it("should return model info from fetched models", async () => { + const testModelInfo = { + maxTokens: 4096, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + } + mockGetModels.mockResolvedValue({ + "some-model": testModelInfo, + }) + + const handlerWithModel = new ChutesHandler({ + apiModelId: "some-model", + chutesApiKey: "test-key", + }) + const model = await handlerWithModel.fetchModel() + expect(model.id).toBe("some-model") + expect(model.info).toEqual(expect.objectContaining(testModelInfo)) + }) + + it("should fall back to global cache when instance models are empty", () => { + const cachedInfo = { + maxTokens: 2048, + contextWindow: 64000, + supportsImages: false, + supportsPromptCache: false, + } + mockGetModelsFromCache.mockReturnValue({ + "cached-model": cachedInfo, + }) + + const handlerWithModel = new ChutesHandler({ + apiModelId: "cached-model", + chutesApiKey: "test-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe("cached-model") + expect(model.info).toEqual(expect.objectContaining(cachedInfo)) + }) + + it("should apply DeepSeek default temperature for R1 models", () => { + const r1Info = { + maxTokens: 32768, + contextWindow: 163840, + supportsImages: false, + supportsPromptCache: false, + } + mockGetModelsFromCache.mockReturnValue({ + "deepseek-ai/DeepSeek-R1-0528": r1Info, + }) + + const handlerWithModel = new ChutesHandler({ + apiModelId: "deepseek-ai/DeepSeek-R1-0528", + chutesApiKey: "test-key", + }) + const model = handlerWithModel.getModel() + expect(model.info.defaultTemperature).toBe(DEEP_SEEK_DEFAULT_TEMPERATURE) + expect(model.temperature).toBe(DEEP_SEEK_DEFAULT_TEMPERATURE) + }) + + it("should use default temperature for non-DeepSeek models", () => { + const modelInfo = { + maxTokens: 4096, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + } + mockGetModelsFromCache.mockReturnValue({ + "unsloth/Llama-3.3-70B-Instruct": modelInfo, + }) + + const handlerWithModel = new ChutesHandler({ + apiModelId: "unsloth/Llama-3.3-70B-Instruct", + chutesApiKey: "test-key", + }) + const model = handlerWithModel.getModel() + expect(model.info.defaultTemperature).toBe(0.5) + expect(model.temperature).toBe(0.5) }) + }) - const stream = handler.createMessage(systemPrompt, messages) - const chunks = [] - for await (const chunk of stream) { - chunks.push(chunk) - } - - expect(chunks).toEqual([ - { type: "reasoning", text: "Thinking..." }, - { type: "text", text: "Hello" }, - { type: "usage", inputTokens: 10, outputTokens: 5 }, - ]) + describe("fetchModel", () => { + it("should fetch models and return the resolved model", async () => { + const model = await handler.fetchModel() + expect(mockGetModels).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "chutes", + }), + ) + expect(model.id).toBe(chutesDefaultModelId) + }) }) - it("should handle non-DeepSeek models", async () => { - // Use default mock implementation which returns text content + describe("createMessage", () => { const systemPrompt = "You are a helpful assistant." const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }] - mockFetchModel.mockResolvedValueOnce({ - id: "some-other-model", - info: { maxTokens: 1024, temperature: 0.7 }, - }) - const stream = handler.createMessage(systemPrompt, messages) - const chunks = [] - for await (const chunk of stream) { - chunks.push(chunk) - } + it("should handle non-DeepSeek models with standard streaming", async () => { + mockGetModels.mockResolvedValue({ + "some-other-model": { maxTokens: 1024, contextWindow: 8192, supportsPromptCache: false }, + }) - expect(chunks).toEqual([ - { type: "text", text: "Test response" }, - { type: "usage", inputTokens: 10, outputTokens: 5 }, - ]) - }) + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } - it("should return default model when no model is specified", async () => { - const model = await handler.fetchModel() - expect(model.id).toBe(chutesDefaultModelId) - expect(model.info).toEqual(expect.objectContaining(chutesDefaultModelInfo)) - }) + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const handlerWithModel = new ChutesHandler({ + apiModelId: "some-other-model", + chutesApiKey: "test-key", + }) + + const stream = handlerWithModel.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } - it("should return specified model when valid model is provided", async () => { - const testModelId = "deepseek-ai/DeepSeek-R1" - const handlerWithModel = new ChutesHandler({ - apiModelId: testModelId, - chutesApiKey: "test-chutes-api-key", - }) - // Mock fetchModel for this handler to return the test model from dynamic fetch - handlerWithModel.fetchModel = vi.fn().mockResolvedValue({ - id: testModelId, - info: { maxTokens: 32768, contextWindow: 163840, supportsImages: false, supportsPromptCache: false }, + expect(chunks).toEqual([ + { type: "text", text: "Test response" }, + { + type: "usage", + inputTokens: 10, + outputTokens: 5, + cacheReadTokens: undefined, + reasoningTokens: undefined, + }, + ]) }) - const model = await handlerWithModel.fetchModel() - expect(model.id).toBe(testModelId) - }) - it("completePrompt method should return text from Chutes API", async () => { - const expectedResponse = "This is a test response from Chutes" - mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] }) - const result = await handler.completePrompt("test prompt") - expect(result).toBe(expectedResponse) - }) + it("should handle DeepSeek R1 reasoning format with TagMatcher", async () => { + mockGetModels.mockResolvedValue({ + "deepseek-ai/DeepSeek-R1-0528": { + maxTokens: 32768, + contextWindow: 163840, + supportsImages: false, + supportsPromptCache: false, + }, + }) - it("should handle errors in completePrompt", async () => { - const errorMessage = "Chutes API error" - mockCreate.mockRejectedValueOnce(new Error(errorMessage)) - await expect(handler.completePrompt("test prompt")).rejects.toThrow(`Chutes completion error: ${errorMessage}`) - }) + async function* mockFullStream() { + yield { type: "text-delta", text: "Thinking..." } + yield { type: "text-delta", text: "Hello" } + } - it("createMessage should yield text content from stream", async () => { - const testContent = "This is test content from Chutes stream" - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { choices: [{ delta: { content: testContent } }] }, - }) - .mockResolvedValueOnce({ done: true }), - }), + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const handlerWithModel = new ChutesHandler({ + apiModelId: "deepseek-ai/DeepSeek-R1-0528", + chutesApiKey: "test-key", + }) + + const stream = handlerWithModel.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } + + expect(chunks).toEqual([ + { type: "reasoning", text: "Thinking..." }, + { type: "text", text: "Hello" }, + { + type: "usage", + inputTokens: 10, + outputTokens: 5, + cacheReadTokens: undefined, + reasoningTokens: undefined, + }, + ]) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should handle tool calls in R1 path", async () => { + mockGetModels.mockResolvedValue({ + "deepseek-ai/DeepSeek-R1-0528": { + maxTokens: 32768, + contextWindow: 163840, + supportsImages: false, + supportsPromptCache: false, + }, + }) - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ type: "text", text: testContent }) - }) + async function* mockFullStream() { + yield { type: "text-delta", text: "Let me help" } + yield { + type: "tool-input-start", + id: "call_123", + toolName: "test_tool", + } + yield { + type: "tool-input-delta", + id: "call_123", + delta: '{"arg":"value"}', + } + yield { + type: "tool-input-end", + id: "call_123", + } + } - it("createMessage should yield usage data from stream", async () => { - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } }, - }) - .mockResolvedValueOnce({ done: true }), - }), + const mockUsage = Promise.resolve({ + inputTokens: 15, + outputTokens: 10, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const handlerWithModel = new ChutesHandler({ + apiModelId: "deepseek-ai/DeepSeek-R1-0528", + chutesApiKey: "test-key", + }) + + const stream = handlerWithModel.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } + + expect(chunks).toContainEqual({ type: "text", text: "Let me help" }) + expect(chunks).toContainEqual({ + type: "tool_call_start", + id: "call_123", + name: "test_tool", + }) + expect(chunks).toContainEqual({ + type: "tool_call_delta", + id: "call_123", + delta: '{"arg":"value"}', + }) + expect(chunks).toContainEqual({ + type: "tool_call_end", + id: "call_123", + }) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should merge system prompt into first user message for R1 path", async () => { + mockGetModels.mockResolvedValue({ + "deepseek-ai/DeepSeek-R1-0528": { + maxTokens: 32768, + contextWindow: 163840, + supportsImages: false, + supportsPromptCache: false, + }, + }) - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 }) - }) + async function* mockFullStream() { + yield { type: "text-delta", text: "Response" } + } - it("createMessage should yield tool_call_partial from stream", async () => { - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_123", - function: { name: "test_tool", arguments: '{"arg":"value"}' }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ done: true }), + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 5, outputTokens: 3 }), + }) + + const handlerWithModel = new ChutesHandler({ + apiModelId: "deepseek-ai/DeepSeek-R1-0528", + chutesApiKey: "test-key", + }) + + const stream = handlerWithModel.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.any(Array), }), + ) + + const callArgs = mockStreamText.mock.calls[0][0] + expect(callArgs.system).toBeUndefined() + }) + + it("should pass system prompt separately for non-R1 path", async () => { + mockGetModels.mockResolvedValue({ + "some-model": { maxTokens: 1024, contextWindow: 8192, supportsPromptCache: false }, + }) + + async function* mockFullStream() { + yield { type: "text-delta", text: "Response" } } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 5, outputTokens: 3 }), + }) + + const handlerWithModel = new ChutesHandler({ + apiModelId: "some-model", + chutesApiKey: "test-key", + }) + + const stream = handlerWithModel.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + system: systemPrompt, + }), + ) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should include usage information from stream", async () => { + mockGetModels.mockResolvedValue({ + "some-model": { maxTokens: 1024, contextWindow: 8192, supportsPromptCache: false }, + }) + + async function* mockFullStream() { + yield { type: "text-delta", text: "Hello" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ + inputTokens: 20, + outputTokens: 10, + }), + }) + + const handlerWithModel = new ChutesHandler({ + apiModelId: "some-model", + chutesApiKey: "test-key", + }) - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ - type: "tool_call_partial", - index: 0, - id: "call_123", - name: "test_tool", - arguments: '{"arg":"value"}', + const stream = handlerWithModel.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((c) => c.type === "usage") + expect(usageChunks).toHaveLength(1) + expect(usageChunks[0].inputTokens).toBe(20) + expect(usageChunks[0].outputTokens).toBe(10) }) }) - it("createMessage should pass tools and tool_choice to API", async () => { - const tools = [ - { - type: "function" as const, - function: { - name: "test_tool", - description: "A test tool", - parameters: { type: "object", properties: {} }, - }, - }, - ] - const tool_choice = "auto" as const - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi.fn().mockResolvedValueOnce({ done: true }), + describe("completePrompt", () => { + it("should return text from generateText", async () => { + const expectedResponse = "This is a test response from Chutes" + mockGenerateText.mockResolvedValue({ text: expectedResponse }) + + const result = await handler.completePrompt("test prompt") + expect(result).toBe(expectedResponse) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "test prompt", }), - } + ) }) - const stream = handler.createMessage("system prompt", [], { tools, tool_choice, taskId: "test-task-id" }) - // Consume stream - for await (const _ of stream) { - // noop - } - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tools, - tool_choice, - }), - ) - }) + it("should handle errors in completePrompt", async () => { + const errorMessage = "Chutes API error" + mockGenerateText.mockRejectedValue(new Error(errorMessage)) + await expect(handler.completePrompt("test prompt")).rejects.toThrow( + `Chutes completion error: ${errorMessage}`, + ) + }) - it("should apply DeepSeek default temperature for R1 models", () => { - const testModelId = "deepseek-ai/DeepSeek-R1" - const handlerWithModel = new ChutesHandler({ - apiModelId: testModelId, - chutesApiKey: "test-chutes-api-key", + it("should pass temperature for R1 models in completePrompt", async () => { + mockGetModels.mockResolvedValue({ + "deepseek-ai/DeepSeek-R1-0528": { + maxTokens: 32768, + contextWindow: 163840, + supportsImages: false, + supportsPromptCache: false, + }, + }) + + mockGenerateText.mockResolvedValue({ text: "response" }) + + const handlerWithModel = new ChutesHandler({ + apiModelId: "deepseek-ai/DeepSeek-R1-0528", + chutesApiKey: "test-key", + }) + + await handlerWithModel.completePrompt("test prompt") + + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: DEEP_SEEK_DEFAULT_TEMPERATURE, + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.info.temperature).toBe(DEEP_SEEK_DEFAULT_TEMPERATURE) }) - it("should use default temperature for non-DeepSeek models", () => { - const testModelId = "unsloth/Llama-3.3-70B-Instruct" - const handlerWithModel = new ChutesHandler({ - apiModelId: testModelId, - chutesApiKey: "test-chutes-api-key", + describe("isAiSdkProvider", () => { + it("should return true", () => { + expect(handler.isAiSdkProvider()).toBe(true) }) - // Note: getModel() returns fallback default without calling fetchModel - // Since we haven't called fetchModel, it returns the default chutesDefaultModelId - // which is DeepSeek-R1-0528, therefore temperature will be DEEP_SEEK_DEFAULT_TEMPERATURE - const model = handlerWithModel.getModel() - // The default model is DeepSeek-R1, so it returns DEEP_SEEK_DEFAULT_TEMPERATURE - expect(model.info.temperature).toBe(DEEP_SEEK_DEFAULT_TEMPERATURE) }) }) diff --git a/src/api/providers/chutes.ts b/src/api/providers/chutes.ts index 6b040834cd8..66e1d6c9879 100644 --- a/src/api/providers/chutes.ts +++ b/src/api/providers/chutes.ts @@ -1,62 +1,110 @@ -import { DEEP_SEEK_DEFAULT_TEMPERATURE, chutesDefaultModelId, chutesDefaultModelInfo } from "@roo-code/types" import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI from "openai" +import { streamText, generateText, LanguageModel, ToolSet } from "ai" + +import { + DEEP_SEEK_DEFAULT_TEMPERATURE, + chutesDefaultModelId, + chutesDefaultModelInfo, + type ModelInfo, + type ModelRecord, +} from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" import { getModelMaxOutputTokens } from "../../shared/api" import { TagMatcher } from "../../utils/tag-matcher" -import { convertToR1Format } from "../transform/r1-format" -import { convertToOpenAiMessages } from "../transform/openai-format" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from "../transform/ai-sdk" import { ApiStream } from "../transform/stream" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" -import { RouterProvider } from "./router-provider" +import { OpenAICompatibleHandler, OpenAICompatibleConfig } from "./openai-compatible" +import { getModels, getModelsFromCache } from "./fetchers/modelCache" + +export class ChutesHandler extends OpenAICompatibleHandler implements SingleCompletionHandler { + private models: ModelRecord = {} -export class ChutesHandler extends RouterProvider implements SingleCompletionHandler { constructor(options: ApiHandlerOptions) { - super({ - options, - name: "chutes", + const modelId = options.apiModelId ?? chutesDefaultModelId + + const config: OpenAICompatibleConfig = { + providerName: "chutes", baseURL: "https://llm.chutes.ai/v1", - apiKey: options.chutesApiKey, - modelId: options.apiModelId, - defaultModelId: chutesDefaultModelId, - defaultModelInfo: chutesDefaultModelInfo, - }) + apiKey: options.chutesApiKey ?? "not-provided", + modelId, + modelInfo: chutesDefaultModelInfo, + } + + super(options, config) } - private getCompletionParams( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - metadata?: ApiHandlerCreateMessageMetadata, - ): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming { - const { id: model, info } = this.getModel() + async fetchModel() { + this.models = await getModels({ provider: "chutes", apiKey: this.config.apiKey, baseUrl: this.config.baseURL }) + return this.getModel() + } + + override getModel(): { id: string; info: ModelInfo; temperature?: number } { + const id = this.options.apiModelId ?? chutesDefaultModelId + + let info: ModelInfo | undefined = this.models[id] - // Centralized cap: clamp to 20% of the context window (unless provider-specific exceptions apply) - const max_tokens = + if (!info) { + const cachedModels = getModelsFromCache("chutes") + if (cachedModels?.[id]) { + this.models = cachedModels + info = cachedModels[id] + } + } + + if (!info) { + const isDeepSeekR1 = chutesDefaultModelId.includes("DeepSeek-R1") + const defaultTemp = isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0.5 + return { + id: chutesDefaultModelId, + info: { + ...chutesDefaultModelInfo, + defaultTemperature: defaultTemp, + }, + temperature: this.options.modelTemperature ?? defaultTemp, + } + } + + const isDeepSeekR1 = id.includes("DeepSeek-R1") + const defaultTemp = isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0.5 + + return { + id, + info: { + ...info, + defaultTemperature: defaultTemp, + }, + temperature: this.supportsTemperature(id) ? (this.options.modelTemperature ?? defaultTemp) : undefined, + } + } + + protected override getLanguageModel(): LanguageModel { + const { id } = this.getModel() + return this.provider(id) + } + + protected override getMaxOutputTokens(): number | undefined { + const { id, info } = this.getModel() + return ( getModelMaxOutputTokens({ - modelId: model, + modelId: id, model: info, settings: this.options, format: "openai", }) ?? undefined + ) + } - const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { - model, - max_tokens, - messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)], - stream: true, - stream_options: { include_usage: true }, - tools: metadata?.tools, - tool_choice: metadata?.tool_choice, - } - - // Only add temperature if model supports it - if (this.supportsTemperature(model)) { - params.temperature = this.options.modelTemperature ?? info.temperature - } - - return params + private supportsTemperature(modelId: string): boolean { + return !modelId.startsWith("openai/o3-mini") } override async *createMessage( @@ -67,125 +115,123 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan const model = await this.fetchModel() if (model.id.includes("DeepSeek-R1")) { - const stream = await this.client.chat.completions.create({ - ...this.getCompletionParams(systemPrompt, messages, metadata), - messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]), - }) + yield* this.createR1Message(systemPrompt, messages, model, metadata) + } else { + yield* super.createMessage(systemPrompt, messages, metadata) + } + } - const matcher = new TagMatcher( - "think", - (chunk) => - ({ - type: chunk.matched ? "reasoning" : "text", - text: chunk.data, - }) as const, - ) + private async *createR1Message( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + model: { id: string; info: ModelInfo }, + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const languageModel = this.getLanguageModel() + + const modifiedMessages = [...messages] as Anthropic.Messages.MessageParam[] + + if (modifiedMessages.length > 0 && modifiedMessages[0].role === "user") { + const first = modifiedMessages[0] + if (typeof first.content === "string") { + modifiedMessages[0] = { role: "user", content: `${systemPrompt}\n\n${first.content}` } + } else { + modifiedMessages[0] = { + role: "user", + content: [{ type: "text", text: systemPrompt }, ...first.content], + } + } + } else { + modifiedMessages.unshift({ role: "user", content: systemPrompt }) + } - for await (const chunk of stream) { - const delta = chunk.choices[0]?.delta + const aiSdkMessages = convertToAiSdkMessages(modifiedMessages) - if (delta?.content) { - for (const processedChunk of matcher.update(delta.content)) { - yield processedChunk - } - } + const openAiTools = this.convertToolsForOpenAI(metadata?.tools) + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined - // Emit raw tool call chunks - NativeToolCallParser handles state management - if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } + const maxOutputTokens = + getModelMaxOutputTokens({ + modelId: model.id, + model: model.info, + settings: this.options, + format: "openai", + }) ?? undefined + + const temperature = this.supportsTemperature(model.id) + ? (this.options.modelTemperature ?? model.info.defaultTemperature) + : undefined + + const result = streamText({ + model: languageModel, + messages: aiSdkMessages, + temperature, + maxOutputTokens, + tools: aiSdkTools, + toolChoice: mapToolChoice(metadata?.tool_choice), + }) + + const matcher = new TagMatcher( + "think", + (chunk) => + ({ + type: chunk.matched ? "reasoning" : "text", + text: chunk.data, + }) as const, + ) - if (chunk.usage) { - yield { - type: "usage", - inputTokens: chunk.usage.prompt_tokens || 0, - outputTokens: chunk.usage.completion_tokens || 0, + try { + for await (const part of result.fullStream) { + if (part.type === "text-delta") { + for (const processedChunk of matcher.update(part.text)) { + yield processedChunk + } + } else { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk } } } - // Process any remaining content for (const processedChunk of matcher.final()) { yield processedChunk } - } else { - // For non-DeepSeek-R1 models, use standard OpenAI streaming - const stream = await this.client.chat.completions.create( - this.getCompletionParams(systemPrompt, messages, metadata), - ) - - for await (const chunk of stream) { - const delta = chunk.choices[0]?.delta - - if (delta?.content) { - yield { type: "text", text: delta.content } - } - - if (delta && "reasoning_content" in delta && delta.reasoning_content) { - yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" } - } - - // Emit raw tool call chunks - NativeToolCallParser handles state management - if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } - if (chunk.usage) { - yield { - type: "usage", - inputTokens: chunk.usage.prompt_tokens || 0, - outputTokens: chunk.usage.completion_tokens || 0, - } - } + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage) } + } catch (error) { + throw handleAiSdkError(error, "chutes") } } - async completePrompt(prompt: string): Promise { + override async completePrompt(prompt: string): Promise { const model = await this.fetchModel() - const { id: modelId, info } = model + const languageModel = this.getLanguageModel() - try { - // Centralized cap: clamp to 20% of the context window (unless provider-specific exceptions apply) - const max_tokens = - getModelMaxOutputTokens({ - modelId, - model: info, - settings: this.options, - format: "openai", - }) ?? undefined - - const requestParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { - model: modelId, - messages: [{ role: "user", content: prompt }], - max_tokens, - } + const maxOutputTokens = + getModelMaxOutputTokens({ + modelId: model.id, + model: model.info, + settings: this.options, + format: "openai", + }) ?? undefined - // Only add temperature if model supports it - if (this.supportsTemperature(modelId)) { - const isDeepSeekR1 = modelId.includes("DeepSeek-R1") - const defaultTemperature = isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0.5 - requestParams.temperature = this.options.modelTemperature ?? defaultTemperature - } + const isDeepSeekR1 = model.id.includes("DeepSeek-R1") + const defaultTemperature = isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0.5 + const temperature = this.supportsTemperature(model.id) + ? (this.options.modelTemperature ?? defaultTemperature) + : undefined - const response = await this.client.chat.completions.create(requestParams) - return response.choices[0]?.message.content || "" + try { + const { text } = await generateText({ + model: languageModel, + prompt, + maxOutputTokens, + temperature, + }) + return text } catch (error) { if (error instanceof Error) { throw new Error(`Chutes completion error: ${error.message}`) @@ -193,17 +239,4 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan throw error } } - - override getModel() { - const model = super.getModel() - const isDeepSeekR1 = model.id.includes("DeepSeek-R1") - - return { - ...model, - info: { - ...model.info, - temperature: isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0.5, - }, - } - } }