diff --git a/e2e/react-start/server-functions/src/routes/factory/-functions/functions.ts b/e2e/react-start/server-functions/src/routes/factory/-functions/functions.ts index 05dd8f46bd9..c5df9d71653 100644 --- a/e2e/react-start/server-functions/src/routes/factory/-functions/functions.ts +++ b/e2e/react-start/server-functions/src/routes/factory/-functions/functions.ts @@ -6,6 +6,8 @@ import { createFakeFn } from './createFakeFn' import { reexportFactory } from './reexportIndex' // Test star re-export syntax: `export * from './module'` import { starReexportFactory } from './starReexportIndex' +// Test nested star re-export syntax: A -> B -> C chain +import { nestedReexportFactory } from './nestedReexportA' export const fooFn = createFooServerFn().handler(({ context }) => { return { @@ -115,3 +117,14 @@ export const starReexportedFactoryFn = starReexportFactory().handler( } }, ) + +// Test that nested star re-exported factories (A -> B -> C chain) work correctly +// The middleware from nestedReexportFactory should execute and add { nested: 'nested-middleware-executed' } to context +export const nestedReexportedFactoryFn = nestedReexportFactory().handler( + ({ context }) => { + return { + name: 'nestedReexportedFactoryFn', + context, + } + }, +) diff --git a/e2e/react-start/server-functions/src/routes/factory/-functions/nestedReexportA.ts b/e2e/react-start/server-functions/src/routes/factory/-functions/nestedReexportA.ts new file mode 100644 index 00000000000..dd562a29a7c --- /dev/null +++ b/e2e/react-start/server-functions/src/routes/factory/-functions/nestedReexportA.ts @@ -0,0 +1,7 @@ +/** + * Top-level module in the nested re-export chain. + * Re-exports everything from nestedReexportB. + * + * Chain: nestedReexportA (this file) -> nestedReexportB -> nestedReexportC + */ +export * from './nestedReexportB' diff --git a/e2e/react-start/server-functions/src/routes/factory/-functions/nestedReexportB.ts b/e2e/react-start/server-functions/src/routes/factory/-functions/nestedReexportB.ts new file mode 100644 index 00000000000..07f1a8d8b5d --- /dev/null +++ b/e2e/react-start/server-functions/src/routes/factory/-functions/nestedReexportB.ts @@ -0,0 +1,7 @@ +/** + * Middle module in the nested re-export chain. + * Re-exports everything from nestedReexportC. + * + * Chain: nestedReexportA -> nestedReexportB (this file) -> nestedReexportC + */ +export * from './nestedReexportC' diff --git a/e2e/react-start/server-functions/src/routes/factory/-functions/nestedReexportC.ts b/e2e/react-start/server-functions/src/routes/factory/-functions/nestedReexportC.ts new file mode 100644 index 00000000000..209e7d74825 --- /dev/null +++ b/e2e/react-start/server-functions/src/routes/factory/-functions/nestedReexportC.ts @@ -0,0 +1,20 @@ +/** + * This is the deepest module in the nested re-export chain. + * It defines a server function factory with middleware. + * + * Chain: nestedReexportA -> nestedReexportB -> nestedReexportC (this file) + */ +import { createMiddleware, createServerFn } from '@tanstack/react-start' + +const nestedMiddleware = createMiddleware({ type: 'function' }).server( + ({ next }) => { + console.log('nested middleware triggered') + return next({ + context: { nested: 'nested-middleware-executed' } as const, + }) + }, +) + +export const nestedReexportFactory = createServerFn({ + method: 'GET', +}).middleware([nestedMiddleware]) diff --git a/e2e/react-start/server-functions/src/routes/factory/index.tsx b/e2e/react-start/server-functions/src/routes/factory/index.tsx index 8be82b597fb..76d0450a15d 100644 --- a/e2e/react-start/server-functions/src/routes/factory/index.tsx +++ b/e2e/react-start/server-functions/src/routes/factory/index.tsx @@ -12,6 +12,7 @@ import { fooFnPOST, localFn, localFnPOST, + nestedReexportedFactoryFn, reexportedFactoryFn, starReexportedFactoryFn, } from './-functions/functions' @@ -152,6 +153,16 @@ const functions = { context: { starReexport: 'star-reexport-middleware-executed' }, }, }, + // Test that nested star re-exported factories (A -> B -> C chain) work correctly + // The middleware from nestedReexportFactory should execute and add { nested: 'nested-middleware-executed' } to context + nestedReexportedFactoryFn: { + fn: nestedReexportedFactoryFn, + type: 'serverFn', + expected: { + name: 'nestedReexportedFactoryFn', + context: { nested: 'nested-middleware-executed' }, + }, + }, } satisfies Record interface TestCase { diff --git a/e2e/react-start/server-functions/tests/server-functions.spec.ts b/e2e/react-start/server-functions/tests/server-functions.spec.ts index 0b56c3dbc8d..614e6d6b59e 100644 --- a/e2e/react-start/server-functions/tests/server-functions.spec.ts +++ b/e2e/react-start/server-functions/tests/server-functions.spec.ts @@ -591,3 +591,27 @@ test('star re-exported server function factory middleware executes correctly', a page.getByTestId('fn-comparison-starReexportedFactoryFn'), ).toContainText('equal') }) + +test('nested star re-exported server function factory middleware executes correctly', async ({ + page, +}) => { + // This test specifically verifies that when a server function factory is re-exported + // through a nested chain (A -> B -> C) using `export * from './module'` syntax, + // the middleware still executes correctly. + await page.goto('/factory') + + await expect(page.getByTestId('factory-route-component')).toBeInViewport() + + // Click the button for the nested re-exported factory function + await page.getByTestId('btn-fn-nestedReexportedFactoryFn').click() + + // Wait for the result + await expect( + page.getByTestId('fn-result-nestedReexportedFactoryFn'), + ).toContainText('nested-middleware-executed') + + // Verify the full context was returned (middleware executed) + await expect( + page.getByTestId('fn-comparison-nestedReexportedFactoryFn'), + ).toContainText('equal') +}) diff --git a/packages/start-plugin-core/src/create-server-fn-plugin/compiler.ts b/packages/start-plugin-core/src/create-server-fn-plugin/compiler.ts index 84212206eaf..f201d17a794 100644 --- a/packages/start-plugin-core/src/create-server-fn-plugin/compiler.ts +++ b/packages/start-plugin-core/src/create-server-fn-plugin/compiler.ts @@ -8,6 +8,7 @@ import { } from 'babel-dead-code-elimination' import { handleCreateServerFn } from './handleCreateServerFn' import { handleCreateMiddleware } from './handleCreateMiddleware' +import type { MethodChainPaths, RewriteCandidate } from './types' type Binding = | { @@ -41,6 +42,16 @@ const LookupSetup: Record< }, } +// Pre-computed map: identifier name -> LookupKind for fast candidate detection +const IdentifierToKind = new Map() +for (const [kind, setup] of Object.entries(LookupSetup) as Array< + [LookupKind, { candidateCallIdentifier: Set }] +>) { + for (const id of setup.candidateCallIdentifier) { + IdentifierToKind.set(id, kind) + } +} + export type LookupConfig = { libName: string rootExport: string @@ -59,6 +70,10 @@ export class ServerFnCompiler { private moduleCache = new Map() private initialized = false private validLookupKinds: Set + // Fast lookup for direct imports from known libraries (e.g., '@tanstack/react-start') + // Maps: libName → (exportName → Kind) + // This allows O(1) resolution for the common case without async resolveId calls + private knownRootImports = new Map>() constructor( private options: { env: 'client' | 'server' @@ -108,6 +123,14 @@ export class ServerFnCompiler { resolvedKind: `Root` satisfies Kind, }) this.moduleCache.set(libId, rootModule) + + // Also populate the fast lookup map for direct imports + let libExports = this.knownRootImports.get(config.libName) + if (!libExports) { + libExports = new Map() + this.knownRootImports.set(config.libName, libExports) + } + libExports.set(config.rootExport, 'Root') }), ) @@ -247,36 +270,98 @@ export class ServerFnCompiler { } // let's find out which of the candidates are actually server functions - const toRewrite: Array<{ - callExpression: t.CallExpression - kind: LookupKind - }> = [] - for (const handler of candidates) { - const kind = await this.resolveExprKind(handler, id) + // Resolve all candidates in parallel for better performance + const resolvedCandidates = await Promise.all( + candidates.map(async (candidate) => ({ + candidate, + kind: await this.resolveExprKind(candidate, id), + })), + ) + + // Map from candidate/root node -> kind + // Note: For top-level variable declarations, candidate === root (the outermost CallExpression) + const toRewriteMap = new Map() + for (const { candidate, kind } of resolvedCandidates) { if (this.validLookupKinds.has(kind as LookupKind)) { - toRewrite.push({ callExpression: handler, kind: kind as LookupKind }) + toRewriteMap.set(candidate, kind as LookupKind) } } - if (toRewrite.length === 0) { + if (toRewriteMap.size === 0) { return null } + // Single-pass traversal to find NodePaths and collect method chains const pathsToRewrite: Array<{ - nodePath: babel.NodePath + path: babel.NodePath kind: LookupKind + methodChain: MethodChainPaths }> = [] + + // First, collect all CallExpression paths in the AST for O(1) lookup + const callExprPaths = new Map< + t.CallExpression, + babel.NodePath + >() + babel.traverse(ast, { CallExpression(path) { - const found = toRewrite.findIndex((h) => path.node === h.callExpression) - if (found !== -1) { - pathsToRewrite.push({ nodePath: path, kind: toRewrite[found]!.kind }) - // delete from toRewrite - toRewrite.splice(found, 1) - } + callExprPaths.set(path.node, path) }, }) - if (toRewrite.length > 0) { + // Now process candidates - we can look up any CallExpression path in O(1) + for (const [node, kind] of toRewriteMap) { + const path = callExprPaths.get(node) + if (!path) { + continue + } + + // Collect method chain paths by walking DOWN from root through the chain + const methodChain: MethodChainPaths = { + middleware: null, + inputValidator: null, + handler: null, + server: null, + client: null, + } + + // Walk down the call chain using nodes, look up paths from map + let currentNode: t.CallExpression = node + // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition + while (true) { + const callee = currentNode.callee + if (!t.isMemberExpression(callee)) { + break + } + + // Record method chain path if it's a known method + if (t.isIdentifier(callee.property)) { + const name = callee.property.name as keyof MethodChainPaths + if (name in methodChain) { + const currentPath = callExprPaths.get(currentNode)! + // Get first argument path + const args = currentPath.get('arguments') + const firstArgPath = + Array.isArray(args) && args.length > 0 ? (args[0] ?? null) : null + methodChain[name] = { + callPath: currentPath, + firstArgPath, + } + } + } + + // Move to the inner call (the object of the member expression) + if (!t.isCallExpression(callee.object)) { + break + } + currentNode = callee.object + } + + pathsToRewrite.push({ path, kind, methodChain }) + } + + // Verify we found all candidates (pathsToRewrite should have same size as toRewriteMap had) + if (pathsToRewrite.length !== toRewriteMap.size) { throw new Error( `Internal error: could not find all paths to rewrite. please file an issue`, ) @@ -284,18 +369,21 @@ export class ServerFnCompiler { const refIdents = findReferencedIdentifiers(ast) - pathsToRewrite.map((p) => { - if (p.kind === 'ServerFn') { - handleCreateServerFn(p.nodePath, { + for (const { path, kind, methodChain } of pathsToRewrite) { + const candidate: RewriteCandidate = { path, methodChain } + if (kind === 'ServerFn') { + handleCreateServerFn(candidate, { env: this.options.env, code, directive: this.options.directive, isProviderFile, }) } else { - handleCreateMiddleware(p.nodePath, { env: this.options.env }) + handleCreateMiddleware(candidate, { + env: this.options.env, + }) } - }) + } deadCodeElimination(ast, refIdents) @@ -312,12 +400,12 @@ export class ServerFnCompiler { for (const binding of bindings.values()) { if (binding.type === 'var') { - const handler = isCandidateCallExpression( + const candidate = isCandidateCallExpression( binding.init, this.validLookupKinds, ) - if (handler) { - candidates.push(handler) + if (candidate) { + candidates.push(candidate) } } } @@ -352,6 +440,61 @@ export class ServerFnCompiler { return resolvedKind } + /** + * Recursively find an export in a module, following `export * from` chains. + * Returns the module info and binding if found, or undefined if not found. + */ + private async findExportInModule( + moduleInfo: ModuleInfo, + exportName: string, + visitedModules = new Set(), + ): Promise<{ moduleInfo: ModuleInfo; binding: Binding } | undefined> { + // Prevent infinite loops in circular re-exports + if (visitedModules.has(moduleInfo.id)) { + return undefined + } + visitedModules.add(moduleInfo.id) + + // First check direct exports + const directExport = moduleInfo.exports.get(exportName) + if (directExport) { + const binding = moduleInfo.bindings.get(directExport.name) + if (binding) { + return { moduleInfo, binding } + } + } + + // If not found, recursively check re-export-all sources in parallel + // Valid code won't have duplicate exports across chains, so first match wins + if (moduleInfo.reExportAllSources.length > 0) { + const results = await Promise.all( + moduleInfo.reExportAllSources.map(async (reExportSource) => { + const reExportTarget = await this.options.resolveId( + reExportSource, + moduleInfo.id, + ) + if (reExportTarget) { + const reExportModule = await this.getModuleInfo(reExportTarget) + return this.findExportInModule( + reExportModule, + exportName, + visitedModules, + ) + } + return undefined + }), + ) + // Return the first valid result + for (const result of results) { + if (result) { + return result + } + } + } + + return undefined + } + private async resolveBindingKind( binding: Binding, fileId: string, @@ -361,6 +504,19 @@ export class ServerFnCompiler { return binding.resolvedKind } if (binding.type === 'import') { + // Fast path: check if this is a direct import from a known library + // (e.g., import { createServerFn } from '@tanstack/react-start') + // This avoids async resolveId calls for the common case + const knownExports = this.knownRootImports.get(binding.source) + if (knownExports) { + const kind = knownExports.get(binding.importedName) + if (kind) { + binding.resolvedKind = kind + return kind + } + } + + // Slow path: resolve through the module graph const target = await this.options.resolveId(binding.source, fileId) if (!target) { return 'None' @@ -368,60 +524,28 @@ export class ServerFnCompiler { const importedModule = await this.getModuleInfo(target) - // Try to find the export in the module's direct exports - const moduleExport = importedModule.exports.get(binding.importedName) - - // If not found directly, check re-export-all sources (`export * from './module'`) - if (!moduleExport && importedModule.reExportAllSources.length > 0) { - for (const reExportSource of importedModule.reExportAllSources) { - const reExportTarget = await this.options.resolveId( - reExportSource, - importedModule.id, - ) - if (reExportTarget) { - const reExportModule = await this.getModuleInfo(reExportTarget) - const reExportEntry = reExportModule.exports.get( - binding.importedName, - ) - if (reExportEntry) { - // Found the export in a re-exported module, resolve from there - const reExportBinding = reExportModule.bindings.get( - reExportEntry.name, - ) - if (reExportBinding) { - if (reExportBinding.resolvedKind) { - return reExportBinding.resolvedKind - } - const resolvedKind = await this.resolveBindingKind( - reExportBinding, - reExportModule.id, - visited, - ) - reExportBinding.resolvedKind = resolvedKind - return resolvedKind - } - } - } - } - } + // Find the export, recursively searching through export * from chains + const found = await this.findExportInModule( + importedModule, + binding.importedName, + ) - if (!moduleExport) { + if (!found) { return 'None' } - const importedBinding = importedModule.bindings.get(moduleExport.name) - if (!importedBinding) { - return 'None' - } - if (importedBinding.resolvedKind) { - return importedBinding.resolvedKind + + const { moduleInfo: foundModule, binding: foundBinding } = found + + if (foundBinding.resolvedKind) { + return foundBinding.resolvedKind } const resolvedKind = await this.resolveBindingKind( - importedBinding, - importedModule.id, + foundBinding, + foundModule.id, visited, ) - importedBinding.resolvedKind = resolvedKind + foundBinding.resolvedKind = resolvedKind return resolvedKind } @@ -443,6 +567,15 @@ export class ServerFnCompiler { return 'None' } + // Unwrap common TypeScript/parenthesized wrappers first for efficiency + while ( + t.isTSAsExpression(expr) || + t.isTSNonNullExpression(expr) || + t.isParenthesizedExpression(expr) + ) { + expr = expr.expression + } + let result: Kind = 'None' if (t.isCallExpression(expr)) { @@ -454,15 +587,12 @@ export class ServerFnCompiler { fileId, visited, ) - if (calleeKind !== 'None') { - if (calleeKind === `Root` || calleeKind === `Builder`) { - return `Builder` - } - for (const kind of this.validLookupKinds) { - if (calleeKind === kind) { - return kind - } - } + if (calleeKind === 'Root' || calleeKind === 'Builder') { + return 'Builder' + } + // Use direct Set.has() instead of iterating + if (this.validLookupKinds.has(calleeKind as LookupKind)) { + return calleeKind } } else if (t.isMemberExpression(expr) && t.isIdentifier(expr.property)) { result = await this.resolveCalleeKind(expr.object, fileId, visited) @@ -472,16 +602,6 @@ export class ServerFnCompiler { result = await this.resolveIdentifierKind(expr.name, fileId, visited) } - if (result === 'None' && t.isTSAsExpression(expr)) { - result = await this.resolveExprKind(expr.expression, fileId, visited) - } - if (result === 'None' && t.isTSNonNullExpression(expr)) { - result = await this.resolveExprKind(expr.expression, fileId, visited) - } - if (result === 'None' && t.isParenthesizedExpression(expr)) { - result = await this.resolveExprKind(expr.expression, fileId, visited) - } - return result } @@ -576,17 +696,18 @@ export class ServerFnCompiler { function isCandidateCallExpression( node: t.Node | null | undefined, lookupKinds: Set, -): undefined | t.CallExpression { +): t.CallExpression | undefined { if (!t.isCallExpression(node)) return undefined const callee = node.callee if (!t.isMemberExpression(callee) || !t.isIdentifier(callee.property)) { return undefined } - for (const kind of lookupKinds) { - if (LookupSetup[kind].candidateCallIdentifier.has(callee.property.name)) { - return node - } + + // Use pre-computed map for O(1) lookup instead of iterating over lookupKinds + const kind = IdentifierToKind.get(callee.property.name) + if (kind && lookupKinds.has(kind)) { + return node } return undefined diff --git a/packages/start-plugin-core/src/create-server-fn-plugin/handleCreateMiddleware.ts b/packages/start-plugin-core/src/create-server-fn-plugin/handleCreateMiddleware.ts index 60fea360bc2..25c23d0fb1c 100644 --- a/packages/start-plugin-core/src/create-server-fn-plugin/handleCreateMiddleware.ts +++ b/packages/start-plugin-core/src/create-server-fn-plugin/handleCreateMiddleware.ts @@ -1,9 +1,14 @@ import * as t from '@babel/types' -import { getRootCallExpression } from '../start-compiler-plugin/utils' -import type * as babel from '@babel/core' - +import type { RewriteCandidate } from './types' + +/** + * Handles createMiddleware transformations. + * + * @param candidate - The rewrite candidate containing path and method chain + * @param opts - Options including the environment + */ export function handleCreateMiddleware( - path: babel.NodePath, + candidate: RewriteCandidate, opts: { env: 'client' | 'server' }, @@ -11,36 +16,11 @@ export function handleCreateMiddleware( if (opts.env === 'server') { throw new Error('handleCreateMiddleware should not be called on the server') } - const rootCallExpression = getRootCallExpression(path) - - const callExpressionPaths = { - middleware: null as babel.NodePath | null, - inputValidator: null as babel.NodePath | null, - client: null as babel.NodePath | null, - server: null as babel.NodePath | null, - } - - const validMethods = Object.keys(callExpressionPaths) - rootCallExpression.traverse({ - MemberExpression(memberExpressionPath) { - if (t.isIdentifier(memberExpressionPath.node.property)) { - const name = memberExpressionPath.node.property - .name as keyof typeof callExpressionPaths + const { inputValidator, server } = candidate.methodChain - if ( - validMethods.includes(name) && - memberExpressionPath.parentPath.isCallExpression() - ) { - callExpressionPaths[name] = memberExpressionPath.parentPath - } - } - }, - }) - - if (callExpressionPaths.inputValidator) { - const innerInputExpression = - callExpressionPaths.inputValidator.node.arguments[0] + if (inputValidator) { + const innerInputExpression = inputValidator.callPath.node.arguments[0] if (!innerInputExpression) { throw new Error( @@ -49,23 +29,17 @@ export function handleCreateMiddleware( } // remove the validator call expression - if (t.isMemberExpression(callExpressionPaths.inputValidator.node.callee)) { - callExpressionPaths.inputValidator.replaceWith( - callExpressionPaths.inputValidator.node.callee.object, + if (t.isMemberExpression(inputValidator.callPath.node.callee)) { + inputValidator.callPath.replaceWith( + inputValidator.callPath.node.callee.object, ) } } - const serverFnPath = callExpressionPaths.server?.get( - 'arguments.0', - ) as babel.NodePath - - if (callExpressionPaths.server && serverFnPath.node) { + if (server) { // remove the server call expression - if (t.isMemberExpression(callExpressionPaths.server.node.callee)) { - callExpressionPaths.server.replaceWith( - callExpressionPaths.server.node.callee.object, - ) + if (t.isMemberExpression(server.callPath.node.callee)) { + server.callPath.replaceWith(server.callPath.node.callee.object) } } } diff --git a/packages/start-plugin-core/src/create-server-fn-plugin/handleCreateServerFn.ts b/packages/start-plugin-core/src/create-server-fn-plugin/handleCreateServerFn.ts index 2ccbf450d8a..8837e93c7ad 100644 --- a/packages/start-plugin-core/src/create-server-fn-plugin/handleCreateServerFn.ts +++ b/packages/start-plugin-core/src/create-server-fn-plugin/handleCreateServerFn.ts @@ -1,12 +1,15 @@ import * as t from '@babel/types' -import { - codeFrameError, - getRootCallExpression, -} from '../start-compiler-plugin/utils' -import type * as babel from '@babel/core' - +import { codeFrameError } from '../start-compiler-plugin/utils' +import type { RewriteCandidate } from './types' + +/** + * Handles createServerFn transformations. + * + * @param candidate - The rewrite candidate containing path and method chain + * @param opts - Options including the environment, code, directive, and provider file flag + */ export function handleCreateServerFn( - path: babel.NodePath, + candidate: RewriteCandidate, opts: { env: 'client' | 'server' code: string @@ -18,56 +21,27 @@ export function handleCreateServerFn( isProviderFile: boolean }, ) { - // Traverse the member expression and find the call expressions for - // the validator, handler, and middleware methods. Check to make sure they - // are children of the createServerFn call expression. - - const validMethods = ['middleware', 'inputValidator', 'handler'] as const - type ValidMethods = (typeof validMethods)[number] - const callExpressionPaths: Record< - ValidMethods, - babel.NodePath | null - > = { - middleware: null, - inputValidator: null, - handler: null, - } - - const rootCallExpression = getRootCallExpression(path) - - // if (debug) - // console.info( - // 'Handling createServerFn call expression:', - // rootCallExpression.toString(), - // ) + const { path, methodChain } = candidate + const { inputValidator, handler } = methodChain // Check if the call is assigned to a variable - if (!rootCallExpression.parentPath.isVariableDeclarator()) { + if (!path.parentPath.isVariableDeclarator()) { throw new Error('createServerFn must be assigned to a variable!') } // Get the identifier name of the variable - const variableDeclarator = rootCallExpression.parentPath.node - const existingVariableName = (variableDeclarator.id as t.Identifier).name - - rootCallExpression.traverse({ - MemberExpression(memberExpressionPath) { - if (t.isIdentifier(memberExpressionPath.node.property)) { - const name = memberExpressionPath.node.property.name as ValidMethods - - if ( - validMethods.includes(name) && - memberExpressionPath.parentPath.isCallExpression() - ) { - callExpressionPaths[name] = memberExpressionPath.parentPath - } - } - }, - }) + const variableDeclarator = path.parentPath.node + if (!t.isIdentifier(variableDeclarator.id)) { + throw codeFrameError( + opts.code, + variableDeclarator.id.loc!, + 'createServerFn must be assigned to a simple identifier, not a destructuring pattern', + ) + } + const existingVariableName = variableDeclarator.id.name - if (callExpressionPaths.inputValidator) { - const innerInputExpression = - callExpressionPaths.inputValidator.node.arguments[0] + if (inputValidator) { + const innerInputExpression = inputValidator.callPath.node.arguments[0] if (!innerInputExpression) { throw new Error( @@ -77,11 +51,9 @@ export function handleCreateServerFn( // If we're on the client, remove the validator call expression if (opts.env === 'client') { - if ( - t.isMemberExpression(callExpressionPaths.inputValidator.node.callee) - ) { - callExpressionPaths.inputValidator.replaceWith( - callExpressionPaths.inputValidator.node.callee.object, + if (t.isMemberExpression(inputValidator.callPath.node.callee)) { + inputValidator.callPath.replaceWith( + inputValidator.callPath.node.callee.object, ) } } @@ -90,11 +62,9 @@ export function handleCreateServerFn( // First, we need to move the handler function to a nested function call // that is applied to the arguments passed to the server function. - const handlerFnPath = callExpressionPaths.handler?.get( - 'arguments.0', - ) as babel.NodePath + const handlerFnPath = handler?.firstArgPath - if (!callExpressionPaths.handler || !handlerFnPath.node) { + if (!handler || !handlerFnPath?.node) { throw codeFrameError( opts.code, path.node.callee.loc!, @@ -102,6 +72,15 @@ export function handleCreateServerFn( ) } + // Validate the handler argument is an expression (not a SpreadElement, etc.) + if (!t.isExpression(handlerFnPath.node)) { + throw codeFrameError( + opts.code, + handlerFnPath.node.loc!, + `handler() must be called with an expression, not a ${handlerFnPath.node.type}`, + ) + } + const handlerFn = handlerFnPath.node // So, the way we do this is we give the handler function a way @@ -161,6 +140,6 @@ export function handleCreateServerFn( // Caller files must NOT have the second argument because the implementation is already available in the extracted chunk // and including it would duplicate code if (opts.env === 'server' && opts.isProviderFile) { - callExpressionPaths.handler.node.arguments.push(handlerFn) + handler.callPath.node.arguments.push(handlerFn) } } diff --git a/packages/start-plugin-core/src/create-server-fn-plugin/types.ts b/packages/start-plugin-core/src/create-server-fn-plugin/types.ts new file mode 100644 index 00000000000..598c11c8c72 --- /dev/null +++ b/packages/start-plugin-core/src/create-server-fn-plugin/types.ts @@ -0,0 +1,42 @@ +import type * as babel from '@babel/core' +import type * as t from '@babel/types' + +/** + * Info about a method call in the chain, including the call expression path + * and the path to its first argument (if any). + */ +export interface MethodCallInfo { + callPath: babel.NodePath + /** Path to the first argument, or null if no arguments */ + firstArgPath: babel.NodePath | null +} + +/** + * Pre-collected method chain paths for a root call expression. + * This avoids needing to traverse the AST again in handlers. + */ +export interface MethodChainPaths { + middleware: MethodCallInfo | null + inputValidator: MethodCallInfo | null + handler: MethodCallInfo | null + server: MethodCallInfo | null + client: MethodCallInfo | null +} + +export type MethodChainKey = keyof MethodChainPaths + +export const METHOD_CHAIN_KEYS: ReadonlyArray = [ + 'middleware', + 'inputValidator', + 'handler', + 'server', + 'client', +] as const + +/** + * Information about a candidate that needs to be rewritten. + */ +export interface RewriteCandidate { + path: babel.NodePath + methodChain: MethodChainPaths +} diff --git a/packages/start-plugin-core/tests/createMiddleware-create-server-fn-plugin/createMiddleware.test.ts b/packages/start-plugin-core/tests/createMiddleware-create-server-fn-plugin/createMiddleware.test.ts index 337edffecf9..e26d3abeda5 100644 --- a/packages/start-plugin-core/tests/createMiddleware-create-server-fn-plugin/createMiddleware.test.ts +++ b/packages/start-plugin-core/tests/createMiddleware-create-server-fn-plugin/createMiddleware.test.ts @@ -1,6 +1,6 @@ import { readFile, readdir } from 'node:fs/promises' import path from 'node:path' -import { describe, expect, test } from 'vitest' +import { describe, expect, test, vi } from 'vitest' import { ServerFnCompiler } from '../../src/create-server-fn-plugin/compiler' async function getFilenames() { @@ -59,4 +59,99 @@ describe('createMiddleware compiles correctly', async () => { ) }) }) + + test('should use fast path for direct imports from known library (no extra resolveId calls)', async () => { + const code = ` + import { createMiddleware } from '@tanstack/react-start' + const myMiddleware = createMiddleware().server(async ({ next }) => { + return next() + })` + + const resolveIdMock = vi.fn(async (id: string) => id) + + const compiler = new ServerFnCompiler({ + env: 'client', + loadModule: async () => {}, + lookupKinds: new Set(['Middleware']), + lookupConfigurations: [ + { + libName: '@tanstack/react-start', + rootExport: 'createMiddleware', + }, + ], + resolveId: resolveIdMock, + directive: 'use server', + }) + + await compiler.compile({ + code, + id: 'test.ts', + isProviderFile: false, + }) + + // resolveId should only be called once during init() for the library itself + // It should NOT be called again to resolve the import binding because + // the fast path uses knownRootImports map for O(1) lookup + expect(resolveIdMock).toHaveBeenCalledTimes(1) + expect(resolveIdMock).toHaveBeenCalledWith( + '@tanstack/react-start', + 'test.ts', + ) + }) + + test('should use slow path for factory pattern (resolveId called for import resolution)', async () => { + // This simulates a factory pattern where createMiddleware is re-exported from a local file + const factoryCode = ` + import { createFooMiddleware } from './factory' + const myMiddleware = createFooMiddleware().server(async ({ next }) => { + return next() + })` + + const resolveIdMock = vi.fn(async (id: string) => id) + + const compiler = new ServerFnCompiler({ + env: 'client', + loadModule: async (id) => { + // Simulate the factory module being loaded + if (id === './factory') { + compiler.ingestModule({ + code: ` + import { createMiddleware } from '@tanstack/react-start' + export const createFooMiddleware = createMiddleware + `, + id: './factory', + }) + } + }, + lookupKinds: new Set(['Middleware']), + lookupConfigurations: [ + { + libName: '@tanstack/react-start', + rootExport: 'createMiddleware', + }, + ], + resolveId: resolveIdMock, + directive: 'use server', + }) + + await compiler.compile({ + code: factoryCode, + id: 'test.ts', + isProviderFile: false, + }) + + // resolveId should be called exactly twice: + // 1. Once during init() for '@tanstack/react-start' + // 2. Once to resolve './factory' import (slow path - not in knownRootImports) + // + // Note: The factory module's import from '@tanstack/react-start' ALSO uses + // the fast path (knownRootImports), so no additional resolveId call is needed there. + expect(resolveIdMock).toHaveBeenCalledTimes(2) + expect(resolveIdMock).toHaveBeenNthCalledWith( + 1, + '@tanstack/react-start', + 'test.ts', + ) + expect(resolveIdMock).toHaveBeenNthCalledWith(2, './factory', 'test.ts') + }) }) diff --git a/packages/start-plugin-core/tests/createServerFn/createServerFn.test.ts b/packages/start-plugin-core/tests/createServerFn/createServerFn.test.ts index a8404619554..758a61a94eb 100644 --- a/packages/start-plugin-core/tests/createServerFn/createServerFn.test.ts +++ b/packages/start-plugin-core/tests/createServerFn/createServerFn.test.ts @@ -1,6 +1,6 @@ import { readFile, readdir } from 'node:fs/promises' import path from 'node:path' -import { describe, expect, test } from 'vitest' +import { describe, expect, test, vi } from 'vitest' import { ServerFnCompiler } from '../../src/create-server-fn-plugin/compiler' async function getFilenames() { @@ -214,4 +214,99 @@ describe('createServerFn compiles correctly', async () => { });" `) }) + + test('should use fast path for direct imports from known library (no extra resolveId calls)', async () => { + const code = ` + import { createServerFn } from '@tanstack/react-start' + const myServerFn = createServerFn().handler(async () => { + return 'hello' + })` + + const resolveIdMock = vi.fn(async (id: string) => id) + + const compiler = new ServerFnCompiler({ + env: 'client', + loadModule: async () => {}, + lookupKinds: new Set(['ServerFn']), + lookupConfigurations: [ + { + libName: '@tanstack/react-start', + rootExport: 'createServerFn', + }, + ], + resolveId: resolveIdMock, + directive: 'use server', + }) + + await compiler.compile({ + code, + id: 'test.ts', + isProviderFile: false, + }) + + // resolveId should only be called once during init() for the library itself + // It should NOT be called again to resolve the import binding because + // the fast path uses knownRootImports map for O(1) lookup + expect(resolveIdMock).toHaveBeenCalledTimes(1) + expect(resolveIdMock).toHaveBeenCalledWith( + '@tanstack/react-start', + 'test.ts', + ) + }) + + test('should use slow path for factory pattern (resolveId called for import resolution)', async () => { + // This simulates a factory pattern where createServerFn is re-exported from a local file + const factoryCode = ` + import { createFooServerFn } from './factory' + const myServerFn = createFooServerFn().handler(async () => { + return 'hello' + })` + + const resolveIdMock = vi.fn(async (id: string) => id) + + const compiler = new ServerFnCompiler({ + env: 'client', + loadModule: async (id) => { + // Simulate the factory module being loaded + if (id === './factory') { + compiler.ingestModule({ + code: ` + import { createServerFn } from '@tanstack/react-start' + export const createFooServerFn = createServerFn + `, + id: './factory', + }) + } + }, + lookupKinds: new Set(['ServerFn']), + lookupConfigurations: [ + { + libName: '@tanstack/react-start', + rootExport: 'createServerFn', + }, + ], + resolveId: resolveIdMock, + directive: 'use server', + }) + + await compiler.compile({ + code: factoryCode, + id: 'test.ts', + isProviderFile: false, + }) + + // resolveId should be called exactly twice: + // 1. Once during init() for '@tanstack/react-start' + // 2. Once to resolve './factory' import (slow path - not in knownRootImports) + // + // Note: The factory module's import from '@tanstack/react-start' ALSO uses + // the fast path (knownRootImports), so no additional resolveId call is needed there. + expect(resolveIdMock).toHaveBeenCalledTimes(2) + expect(resolveIdMock).toHaveBeenNthCalledWith( + 1, + '@tanstack/react-start', + 'test.ts', + ) + expect(resolveIdMock).toHaveBeenNthCalledWith(2, './factory', 'test.ts') + }) })