// Source: dotfiles/pi/.pi/agent/extensions/claude-agent-sdk/index.ts
// Snapshot: 2026-04-10 09:01:25 +02:00 · 1187 lines · 37 KiB · TypeScript

/**
* Claude Agent SDK Provider Extension
*
* Routes pi requests through the Claude Code CLI via @anthropic-ai/claude-agent-sdk.
*
* ARCHITECTURE (improved over npm:claude-agent-sdk-pi):
*
* The original extension flattened the entire conversation into a single user
* message on every call, labelling previous tool calls as "non-executable".
* This caused the model to believe its prior tool calls had never run, resulting
* in infinite retry loops.
*
* This version uses proper session persistence + resume:
* - First turn: starts a Claude Code session (persistSession=true, custom UUID)
* - Continuation turns: resumes that session and injects tool results as proper
* SDKUserMessage items with parent_tool_use_id set — giving Claude Code the
* native multi-turn context it needs.
*
* Fallback: if a session can't be resumed (e.g. the file wasn't flushed before
* we closed the query), the full conversation is re-sent as a structured text
* block with clear labels so the model understands what was already done.
*
* BUGS FIXED:
* 1. "Historical tool call (non-executable)" → model thought calls never ran.
* Fixed label + session-resume approach eliminates the confusion entirely.
* 2. Edit tool args: SDK sends {old_string, new_string} but pi's Edit tool
* requires {edits: [{oldText, newText}]}. Now properly wrapped.
*/
import {
calculateCost,
createAssistantMessageEventStream,
getModels,
type AssistantMessage,
type AssistantMessageEventStream,
type Context,
type ImageContent,
type Model,
type SimpleStreamOptions,
type TextContent,
type Tool,
type ToolResultMessage,
type UserMessage,
} from "@mariozechner/pi-ai";
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
import {
createSdkMcpServer,
query,
type SDKMessage,
type SDKUserMessage,
type SettingSource,
} from "@anthropic-ai/claude-agent-sdk";
import type {
Base64ImageSource,
ContentBlockParam,
MessageParam,
} from "@anthropic-ai/sdk/resources";
import { pascalCase } from "change-case";
import { existsSync, readFileSync } from "fs";
import { randomUUID } from "crypto";
import { homedir } from "os";
import { dirname, join, relative, resolve } from "path";
// =============================================================================
// Constants
// =============================================================================
/** Identifier under which this provider is registered with pi. */
const PROVIDER_ID = "claude-agent-sdk";
/**
 * SDK tool name → pi tool name. Keys are lower-cased because lookups are done
 * on the lower-cased SDK name. Note that the SDK's Glob maps to pi's `find`.
 */
const SDK_TO_PI_TOOL_NAME: Record<string, string> = {
read: "read",
write: "write",
edit: "edit",
bash: "bash",
grep: "grep",
glob: "find",
};
/**
 * pi tool name → SDK tool name. Keys are lower-cased pi names; both `find`
 * and `glob` resolve to the SDK's `Glob`.
 */
const PI_TO_SDK_TOOL_NAME: Record<string, string> = {
read: "Read",
write: "Write",
edit: "Edit",
bash: "Bash",
grep: "Grep",
find: "Glob",
glob: "Glob",
};
/** SDK tools enabled when the pi context carries no tool list at all. */
const DEFAULT_TOOLS = ["Read", "Write", "Edit", "Bash", "Grep", "Glob"];
/** Lower-cased pi tool names that have a built-in SDK counterpart. */
const BUILTIN_TOOL_NAMES = new Set(Object.keys(PI_TO_SDK_TOOL_NAME));
/** Message returned whenever the SDK itself tries to execute a tool. */
const TOOL_EXECUTION_DENIED_MESSAGE = "Tool execution is unavailable in this environment.";
/** Name of the in-process MCP server that exposes pi's custom tools. */
const MCP_SERVER_NAME = "custom-tools";
/** Prefix prepended to a custom tool's name to form its SDK-side name. */
const MCP_TOOL_PREFIX = `mcp__${MCP_SERVER_NAME}__`;
/** Claude-style aliases substituted for pi's skill directories. */
const SKILLS_ALIAS_GLOBAL = "~/.claude/skills";
const SKILLS_ALIAS_PROJECT = ".claude/skills";
/** pi's real skill directories. The project root is derived from process.cwd() at module load. */
const GLOBAL_SKILLS_ROOT = join(homedir(), ".pi", "agent", "skills");
const PROJECT_SKILLS_ROOT = join(process.cwd(), ".pi", "skills");
/** pi settings files read for provider settings (global first, then project). */
const GLOBAL_SETTINGS_PATH = join(homedir(), ".pi", "agent", "settings.json");
const PROJECT_SETTINGS_PATH = join(process.cwd(), ".pi", "settings.json");
/** Global AGENTS.md, used when none is found in the working directory or its parents. */
const GLOBAL_AGENTS_PATH = join(homedir(), ".pi", "agent", "AGENTS.md");
// =============================================================================
// Session Registry
// Tracks active Claude Code sessions keyed by pi conversation identity.
// Enables proper session resumption instead of re-encoding history as text.
// =============================================================================
/** Bookkeeping kept per pi conversation so a later call can resume the same Claude Code session. */
interface SessionState {
/** The Claude Code session UUID (used for resume). */
claudeCodeSessionId: string;
/**
* How many messages from context.messages we have already sent to Claude
* Code in prior calls. On the next call we send messages[sentMsgCount..].
*/
sentMsgCount: number;
/** Model used when the session was created (for resume validation). */
modelId: string;
}
/**
 * Module-level registry of session state, keyed by the string produced by
 * getConversationKey. Lives only in memory for the lifetime of this module.
 */
const sessionRegistry = new Map<string, SessionState>();
/**
 * Computes the registry key identifying the current pi conversation.
 *
 * Resolution order:
 *  1. `options.sessionId`, when truthy → `sid:<id>`.
 *  2. The first user message, cut to its first 200 characters → `fp:<prefix>`
 *     (non-string content is JSON-serialised before truncation).
 *  3. No user message at all → `empty:<timestamp>`, which differs on every call.
 */
function getConversationKey(context: Context, options?: SimpleStreamOptions): string {
  const sessionId = (options as { sessionId?: string } | undefined)?.sessionId;
  if (sessionId) {
    return `sid:${sessionId}`;
  }

  for (const candidate of context.messages) {
    if (candidate.role !== "user") continue;
    const raw = (candidate as UserMessage).content;
    const serialized = typeof raw === "string" ? (raw as string) : JSON.stringify(raw);
    return `fp:${serialized.slice(0, 200)}`;
  }

  return `empty:${Date.now()}`;
}
// =============================================================================
// Models
// =============================================================================
/** pi-ai's Anthropic model list, reduced to the fields handed to `registerProvider`. */
const MODELS = getModels("anthropic").map(
  ({ id, name, reasoning, input, cost, contextWindow, maxTokens }) => ({
    id,
    name,
    reasoning,
    input,
    cost,
    contextWindow,
    maxTokens,
  }),
);
// =============================================================================
// Tool Name Mapping
// =============================================================================
/**
 * Translates a pi tool name into the name the SDK knows it by.
 *
 * Lookup order: custom-tool map (exact name, then lower-cased), built-in table
 * (lower-cased), and finally a PascalCase conversion of the original name.
 *
 * @param name pi tool name; an empty or missing name yields "".
 * @param customToolNameToSdk optional map of custom pi tool names to SDK names.
 * @returns the SDK-side tool name.
 */
function mapPiToolNameToSdk(name?: string, customToolNameToSdk?: Map<string, string>): string {
  if (!name) return "";
  const normalized = name.toLowerCase();
  if (customToolNameToSdk) {
    const mapped = customToolNameToSdk.get(name) ?? customToolNameToSdk.get(normalized);
    if (mapped) return mapped;
  }
  // Own-property check: a bare `table[key]` would also find inherited members
  // such as `constructor` and return a function instead of a tool name.
  const builtin = Object.prototype.hasOwnProperty.call(PI_TO_SDK_TOOL_NAME, normalized)
    ? PI_TO_SDK_TOOL_NAME[normalized]
    : undefined;
  if (builtin) return builtin;
  return pascalCase(name);
}
/**
 * Translates an SDK tool name back into the pi tool name.
 *
 * Lookup order: built-in table (lower-cased), custom-tool map (exact name, then
 * lower-cased), then stripping the MCP prefix; otherwise the name is returned
 * unchanged.
 *
 * @param name tool name as reported by the SDK.
 * @param customToolNameToPi optional map of SDK names to custom pi tool names.
 * @returns the pi-side tool name.
 */
function mapToolName(name: string, customToolNameToPi?: Map<string, string>): string {
  const normalized = name.toLowerCase();
  // Own-property check: a bare `table[key]` would also find inherited members
  // such as `constructor` and return a function instead of a tool name.
  const builtin = Object.prototype.hasOwnProperty.call(SDK_TO_PI_TOOL_NAME, normalized)
    ? SDK_TO_PI_TOOL_NAME[normalized]
    : undefined;
  if (builtin) return builtin;
  if (customToolNameToPi) {
    const mapped = customToolNameToPi.get(name) ?? customToolNameToPi.get(normalized);
    if (mapped) return mapped;
  }
  if (normalized.startsWith(MCP_TOOL_PREFIX)) {
    // Slice the original (not lower-cased) name so the tool's casing survives.
    return name.slice(MCP_TOOL_PREFIX.length);
  }
  return name;
}
// =============================================================================
// Tool Argument Mapping
// BUG FIX: Edit previously mapped old_string→oldText as a top-level key,
// but pi's Edit tool requires edits:[{oldText,newText}]. Fixed here.
// =============================================================================
/**
 * Maps a Claude-style skills path back onto pi's skill directories.
 *
 * Non-string values pass through untouched. For strings, the first matching
 * prefix rule wins; rules are tried in the order listed below, so the
 * `./.claude/skills` form is checked before the bare `.claude/skills` form.
 */
function rewriteSkillAliasPath(pathValue: unknown): unknown {
  if (typeof pathValue !== "string") return pathValue;

  const rules: Array<[string, string]> = [
    [SKILLS_ALIAS_GLOBAL, "~/.pi/agent/skills"],
    [`./${SKILLS_ALIAS_PROJECT}`, PROJECT_SKILLS_ROOT],
    [SKILLS_ALIAS_PROJECT, PROJECT_SKILLS_ROOT],
    [join(process.cwd(), SKILLS_ALIAS_PROJECT), PROJECT_SKILLS_ROOT],
  ];

  for (const [prefix, target] of rules) {
    if (pathValue.startsWith(prefix)) {
      return pathValue.replace(prefix, target);
    }
  }
  return pathValue;
}
/**
 * Converts SDK-shaped tool arguments into the shape the pi tool expects.
 *
 * @param toolName pi tool name (matched case-insensitively).
 * @param args raw arguments from the SDK; `undefined` is treated as `{}`.
 * @param allowSkillAliasRewrite when true, path arguments are passed through
 *   `rewriteSkillAliasPath`.
 * @returns the remapped arguments; unknown tools get their input back unchanged.
 */
function mapToolArgs(
  toolName: string,
  args: Record<string, unknown> | undefined,
  allowSkillAliasRewrite = true,
): Record<string, unknown> {
  const source = args ?? {};
  const toPath = (candidate: unknown): unknown =>
    allowSkillAliasRewrite ? rewriteSkillAliasPath(candidate) : candidate;
  const kind = toolName.toLowerCase();

  if (kind === "read") {
    return {
      path: toPath(source.file_path ?? source.path),
      offset: source.offset,
      limit: source.limit,
    };
  }

  if (kind === "write") {
    return {
      path: toPath(source.file_path ?? source.path),
      content: source.content,
    };
  }

  if (kind === "edit") {
    // The SDK sends a single {old_string, new_string} pair at the top level,
    // whereas pi's Edit tool takes {edits: [{oldText, newText}]}.
    const before = source.old_string ?? source.oldText ?? source.old_text;
    const after = source.new_string ?? source.newText ?? source.new_text;
    const hasSinglePair = before !== undefined || after !== undefined;
    if (hasSinglePair) {
      return {
        path: toPath(source.file_path ?? source.path),
        edits: [{ oldText: String(before ?? ""), newText: String(after ?? "") }],
      };
    }
    // No single pair present: pass through an existing edits array, if any.
    return {
      path: toPath(source.file_path ?? source.path),
      edits: source.edits ?? [],
    };
  }

  if (kind === "bash") {
    return {
      command: source.command,
      timeout: source.timeout,
    };
  }

  if (kind === "grep") {
    return {
      pattern: source.pattern,
      path: toPath(source.path),
      glob: source.glob,
      limit: source.head_limit ?? source.limit,
    };
  }

  if (kind === "find") {
    return {
      pattern: source.pattern,
      path: toPath(source.path),
    };
  }

  return source;
}
// =============================================================================
// SDK Tool Resolution
// =============================================================================
/**
 * Splits the pi context's tools into built-in SDK tools and custom tools.
 *
 * Without a tool list, every default SDK tool is enabled and there are no
 * custom tools. Otherwise each tool is either mapped to its built-in SDK name
 * or recorded as a custom tool under an MCP-prefixed SDK name; the two maps
 * translate custom names in both directions (exact and lower-cased keys).
 */
function resolveSdkTools(context: Context): {
  sdkTools: string[];
  customTools: Tool[];
  customToolNameToSdk: Map<string, string>;
  customToolNameToPi: Map<string, string>;
} {
  const customTools: Tool[] = [];
  const customToolNameToSdk = new Map<string, string>();
  const customToolNameToPi = new Map<string, string>();

  if (!context.tools) {
    return { sdkTools: [...DEFAULT_TOOLS], customTools, customToolNameToSdk, customToolNameToPi };
  }

  const builtinSdkNames = new Set<string>();
  context.tools.forEach((tool) => {
    const lowered = tool.name.toLowerCase();

    if (BUILTIN_TOOL_NAMES.has(lowered)) {
      const builtin = PI_TO_SDK_TOOL_NAME[lowered];
      if (builtin) builtinSdkNames.add(builtin);
      return;
    }

    const mcpName = `${MCP_TOOL_PREFIX}${tool.name}`;
    customTools.push(tool);
    customToolNameToSdk.set(tool.name, mcpName);
    customToolNameToSdk.set(lowered, mcpName);
    customToolNameToPi.set(mcpName, tool.name);
    customToolNameToPi.set(mcpName.toLowerCase(), tool.name);
  });

  return {
    sdkTools: [...builtinSdkNames],
    customTools,
    customToolNameToSdk,
    customToolNameToPi,
  };
}
/**
 * Wraps pi's custom tools in an in-process MCP server so the SDK can see them.
 *
 * Every handler answers with the "execution unavailable" error: the tools are
 * only advertised here, not executed. Returns `undefined` when there are no
 * custom tools.
 */
function buildCustomToolServers(
  customTools: Tool[],
): Record<string, ReturnType<typeof createSdkMcpServer>> | undefined {
  if (customTools.length === 0) return undefined;

  const denyExecution = async () => ({
    content: [{ type: "text", text: TOOL_EXECUTION_DENIED_MESSAGE }],
    isError: true,
  });

  const advertised = customTools.map((tool) => ({
    name: tool.name,
    description: tool.description,
    inputSchema: tool.parameters as unknown,
    handler: denyExecution,
  }));

  return {
    [MCP_SERVER_NAME]: createSdkMcpServer({
      name: MCP_SERVER_NAME,
      version: "1.0.0",
      tools: advertised,
    }),
  };
}
// =============================================================================
// Prompt Stream Builders
// =============================================================================
/**
 * Turns pi user-message content into Anthropic content blocks.
 *
 * A plain string becomes one text block. In an array, text parts become text
 * blocks and every other part is treated as a base64 image.
 */
function userContentToBlocks(
  content: string | (TextContent | ImageContent)[],
): ContentBlockParam[] {
  if (typeof content === "string") {
    return [{ type: "text", text: content }];
  }

  const converted: ContentBlockParam[] = [];
  for (const part of content) {
    if (part.type === "text") {
      converted.push({ type: "text", text: part.text });
      continue;
    }
    const image = part as ImageContent;
    converted.push({
      type: "image",
      source: {
        type: "base64",
        media_type: image.mimeType as Base64ImageSource["media_type"],
        data: image.data,
      },
    });
  }
  return converted;
}
/**
 * Flattens pi tool-result content into one newline-joined string.
 * Only text parts contribute; any other part is left out.
 */
function toolResultContentToText(content: (TextContent | ImageContent)[]): string {
  const pieces: string[] = [];
  for (const part of content) {
    if (part.type === "text") {
      pieces.push((part as TextContent).text);
    }
  }
  return pieces.join("\n");
}
/**
 * BUILD STRATEGY A — Session Resume
 *
 * Produces the prompt stream for a resumed session. Only the pi messages at
 * index `sentMsgCount` and later are considered:
 *  - user messages   → one SDK user message, parent_tool_use_id null
 *  - tool results    → one SDK user message carrying a tool_result block,
 *                      parent_tool_use_id set to the tool call id
 *  - assistant (and any other) messages → not emitted
 */
function buildResumeStream(
  context: Context,
  sentMsgCount: number,
  claudeCodeSessionId: string,
): AsyncIterable<SDKUserMessage> {
  const pending = context.messages.slice(sentMsgCount);

  // Shared envelope for both kinds of outgoing message.
  const envelope = (content: unknown, parentToolUseId: string | null) =>
    ({
      type: "user" as const,
      message: {
        role: "user" as const,
        content,
      } as MessageParam,
      parent_tool_use_id: parentToolUseId,
      session_id: claudeCodeSessionId,
    }) satisfies SDKUserMessage;

  return (async function* () {
    for (const entry of pending) {
      switch (entry.role) {
        case "user": {
          const userEntry = entry as UserMessage;
          yield envelope(userContentToBlocks(userEntry.content), null);
          break;
        }
        case "toolResult": {
          const toolEntry = entry as ToolResultMessage;
          const flattened = toolResultContentToText(toolEntry.content);
          yield envelope(
            [
              {
                type: "tool_result" as const,
                tool_use_id: toolEntry.toolCallId,
                content: flattened,
                ...(toolEntry.isError ? { is_error: true } : {}),
              },
            ],
            toolEntry.toolCallId,
          );
          break;
        }
        default:
          // Assistant turns were produced by Claude Code itself; nothing to send.
          break;
      }
    }
  })();
}
/**
 * BUILD STRATEGY B — Text Fallback (used whenever a fresh session is started)
 *
 * Encodes the whole pi conversation as a flat list of content blocks meant to
 * be sent as a single user message:
 *  - user turns are introduced by "USER:", assistant turns by "ASSISTANT:"
 *  - assistant tool calls are rendered as "[TOOL CALL - EXECUTED]" so the
 *    model does not assume they still need to run
 *  - tool results are rendered as "[TOOL RESULT - SUCCESS|FAILED]"
 *  - images are forwarded as real image blocks
 *
 * When a user turn or tool result contains no visible text, a placeholder is
 * added: "(see attached image)" only if an image really was attached,
 * otherwise "(empty message)" / "(no output)". This avoids pointing the model
 * at an image that does not exist, e.g. after a command with empty output.
 *
 * @param context pi context whose `messages` are encoded.
 * @param customToolNameToSdk optional custom-tool name map used for labels.
 * @returns at least one content block (a single empty text block if nothing
 *   was produced).
 */
function buildFallbackTextBlocks(
  context: Context,
  customToolNameToSdk: Map<string, string> | undefined,
): ContentBlockParam[] {
  type ContentSummary = { hasText: boolean; hasImage: boolean };

  const blocks: ContentBlockParam[] = [];
  const pushText = (text: string) => {
    blocks.push({ type: "text", text });
  };
  const pushImage = (image: ImageContent) => {
    blocks.push({
      type: "image",
      source: {
        type: "base64",
        media_type: image.mimeType as Base64ImageSource["media_type"],
        data: image.data,
      },
    });
  };

  // Appends message content and reports what kind of content was seen.
  const appendContent = (
    content: string | Array<{ type: string; text?: string; data?: string; mimeType?: string }>,
  ): ContentSummary => {
    if (typeof content === "string") {
      if (content.length > 0) pushText(content);
      return { hasText: content.trim().length > 0, hasImage: false };
    }
    if (!Array.isArray(content)) return { hasText: false, hasImage: false };
    let hasText = false;
    let hasImage = false;
    for (const block of content) {
      if (block.type === "text") {
        const text = block.text ?? "";
        if (text.trim().length > 0) hasText = true;
        pushText(text);
      } else if (block.type === "image") {
        hasImage = true;
        pushImage(block as ImageContent);
      } else {
        pushText(`[${block.type}]`);
      }
    }
    return { hasText, hasImage };
  };

  // Adds a placeholder only when no visible text was emitted for the entry.
  const pushPlaceholder = (summary: ContentSummary, emptyLabel: string) => {
    if (summary.hasText) return;
    pushText(summary.hasImage ? "(see attached image)" : emptyLabel);
  };

  const messages = context.messages;
  // More than one message means there is history beyond the current user turn.
  const hasPriorTurns = messages.length > 1;

  if (hasPriorTurns) {
    pushText("[CONVERSATION HISTORY - all tool calls below were executed]\n");
  }

  for (const message of messages) {
    if (message.role === "user") {
      pushText(`\nUSER:\n`);
      pushPlaceholder(appendContent((message as UserMessage).content), "(empty message)");
      continue;
    }

    if (message.role === "assistant") {
      pushText(`\nASSISTANT:\n`);
      const assistantContent = (
        message as {
          content: Array<{
            type: string;
            text?: string;
            thinking?: string;
            name?: string;
            arguments?: Record<string, unknown>;
          }>;
        }
      ).content;
      if (Array.isArray(assistantContent)) {
        for (const block of assistantContent) {
          if (block.type === "text") pushText(block.text ?? "");
          else if (block.type === "thinking") pushText(block.thinking ?? "");
          else if (block.type === "toolCall") {
            const toolName = mapPiToolNameToSdk(block.name, customToolNameToSdk);
            const args = block.arguments ? JSON.stringify(block.arguments, null, 2) : "{}";
            // Labelled as EXECUTED: the earlier "non-executable" wording made
            // the model believe the call never ran, causing retry loops.
            pushText(`\n[TOOL CALL - EXECUTED]: ${toolName}\n${args}\n`);
          }
        }
      }
      continue;
    }

    if (message.role === "toolResult") {
      const result = message as ToolResultMessage;
      const toolName = mapPiToolNameToSdk(result.toolName, customToolNameToSdk);
      const status = result.isError ? "FAILED" : "SUCCESS";
      pushText(`\n[TOOL RESULT - ${status}]: ${toolName}\n`);
      pushPlaceholder(appendContent(result.content), "(no output)");
      continue;
    }
  }

  // Explicit continuation marker so the model knows it is mid-task.
  if (hasPriorTurns) {
    pushText(
      "\n\n---\n[TASK STATUS: In progress. The history above shows completed work. Continue from where the assistant left off.]\n",
    );
  }

  if (!blocks.length) return [{ type: "text", text: "" }];
  return blocks;
}
/**
 * Wraps pre-built prompt blocks in a prompt stream that yields exactly one
 * SDK user message (no parent tool use, session_id "prompt").
 */
function buildFallbackStream(promptBlocks: ContentBlockParam[]): AsyncIterable<SDKUserMessage> {
  const onlyMessage = {
    type: "user" as const,
    message: {
      role: "user" as const,
      content: promptBlocks,
    } as MessageParam,
    parent_tool_use_id: null,
    session_id: "prompt",
  } satisfies SDKUserMessage;

  return (async function* () {
    yield onlyMessage;
  })();
}
// =============================================================================
// Settings
// =============================================================================
/**
 * Path to ~/.claude/pi-config.json.
 * NOTE(review): not referenced anywhere in the visible part of this file — confirm whether it is still needed.
 */
const CLAUDE_CONFIG_PATH = join(homedir(), ".claude", "pi-config.json");
/** Provider-specific settings read from pi's settings.json files. Every field is optional. */
type ProviderSettings = {
/** Unless explicitly `false`, AGENTS.md and the skills block are appended to the system prompt. */
appendSystemPrompt?: boolean;
/** Setting sources handed to the SDK; only consulted when `appendSystemPrompt` is `false`. */
settingSources?: SettingSource[];
/** Only consulted when `appendSystemPrompt` is `false`; unless `false`, adds the `strict-mcp-config` extra argument. */
strictMcpConfig?: boolean;
/** Extra arguments merged over the generated ones; a `null` value stands for an argument without a value. */
extraArgs?: Record<string, string | null>;
};
/**
 * Loads provider settings from the global and the project settings file and
 * merges them; keys present in the project result replace the global ones.
 */
function loadProviderSettings(): ProviderSettings {
  const layers = [GLOBAL_SETTINGS_PATH, PROJECT_SETTINGS_PATH].map((settingsPath) =>
    readSettingsFile(settingsPath),
  );
  return Object.assign({}, ...layers);
}
/**
 * Reads the provider's settings block from one pi settings.json file.
 *
 * The block is looked up under `claudeAgentSdkProvider`,
 * `claude-agent-sdk-provider` or `claudeAgentSdk` (first one present wins).
 * Each field is validated individually; fields that are missing or malformed
 * are omitted from the result entirely rather than being set to `undefined`,
 * so spreading this result over another settings object cannot erase values
 * that this file never mentioned.
 *
 * @param filePath absolute path of the settings file.
 * @returns the recognised settings; `{}` when the file is missing, unreadable,
 *   not valid JSON, or has no provider block.
 */
function readSettingsFile(filePath: string): ProviderSettings {
  if (!existsSync(filePath)) return {};
  try {
    const raw = readFileSync(filePath, "utf-8");
    const parsed = JSON.parse(raw) as Record<string, unknown>;
    const settingsBlock =
      (parsed["claudeAgentSdkProvider"] as Record<string, unknown> | undefined) ??
      (parsed["claude-agent-sdk-provider"] as Record<string, unknown> | undefined) ??
      (parsed["claudeAgentSdk"] as Record<string, unknown> | undefined);
    if (!settingsBlock || typeof settingsBlock !== "object") return {};

    const settings: ProviderSettings = {};

    const appendSystemPrompt = settingsBlock["appendSystemPrompt"];
    if (typeof appendSystemPrompt === "boolean") {
      settings.appendSystemPrompt = appendSystemPrompt;
    }

    // Accepted only when every entry is one of the three known source names.
    const settingSourcesRaw = settingsBlock["settingSources"];
    if (
      Array.isArray(settingSourcesRaw) &&
      settingSourcesRaw.every(
        (v) => typeof v === "string" && (v === "user" || v === "project" || v === "local"),
      )
    ) {
      settings.settingSources = settingSourcesRaw as SettingSource[];
    }

    const strictMcpConfig = settingsBlock["strictMcpConfig"];
    if (typeof strictMcpConfig === "boolean") {
      settings.strictMcpConfig = strictMcpConfig;
    }

    // Keep only string or null values; other entries are dropped silently.
    const extraArgsRaw = settingsBlock["extraArgs"];
    if (extraArgsRaw && typeof extraArgsRaw === "object" && !Array.isArray(extraArgsRaw)) {
      settings.extraArgs = Object.fromEntries(
        Object.entries(extraArgsRaw as Record<string, unknown>)
          .filter(([, v]) => typeof v === "string" || v === null)
          .map(([k, v]) => [k, v as string | null]),
      ) as Record<string, string | null>;
    }

    return settings;
  } catch {
    // Best effort: an unreadable or malformed settings file means "no settings".
    return {};
  }
}
// =============================================================================
// System Prompt Helpers
// =============================================================================
/**
 * Cuts the skills section out of pi's system prompt and rewrites its
 * locations. The section runs from the introductory skills sentence through
 * the closing `</available_skills>` tag; returns `undefined` when the prompt
 * is empty or either marker is missing.
 */
function extractSkillsAppend(systemPrompt?: string): string | undefined {
  if (!systemPrompt) return undefined;

  const opening = "The following skills provide specialized instructions for specific tasks.";
  const closing = "</available_skills>";

  const from = systemPrompt.indexOf(opening);
  const to = from === -1 ? -1 : systemPrompt.indexOf(closing, from);
  if (to === -1) return undefined;

  const section = systemPrompt.slice(from, to + closing.length).trim();
  return rewriteSkillsLocations(section);
}
/**
 * Rewrites every `<location>…</location>` entry that starts with one of pi's
 * skill roots so that it uses the matching Claude-style alias instead.
 * Locations outside both roots are left as they are.
 */
function rewriteSkillsLocations(skillsBlock: string): string {
  // Re-roots `location` from `root` onto `alias`, collapsing repeated slashes.
  const reroot = (root: string, alias: string, location: string): string => {
    const tail = relative(root, location).replace(/^\.+/, "");
    return `${alias}/${tail}`.replace(/\/\/+/g, "/");
  };

  return skillsBlock.replace(/<location>([^<]+)<\/location>/g, (_match, location: string) => {
    if (location.startsWith(GLOBAL_SKILLS_ROOT)) {
      return `<location>${reroot(GLOBAL_SKILLS_ROOT, SKILLS_ALIAS_GLOBAL, location)}</location>`;
    }
    if (location.startsWith(PROJECT_SKILLS_ROOT)) {
      return `<location>${reroot(PROJECT_SKILLS_ROOT, SKILLS_ALIAS_PROJECT, location)}</location>`;
    }
    return `<location>${location}</location>`;
  });
}
/**
 * Picks the AGENTS.md to use: the nearest one at or above the working
 * directory, else the global one if it exists, else `undefined`.
 */
function resolveAgentsMdPath(): string | undefined {
  const nearest = findAgentsMdInParents(process.cwd());
  return nearest || (existsSync(GLOBAL_AGENTS_PATH) ? GLOBAL_AGENTS_PATH : undefined);
}
/**
 * Walks from `startDir` up to the filesystem root and returns the path of the
 * first AGENTS.md found, or `undefined` if no directory on the way has one.
 */
function findAgentsMdInParents(startDir: string): string | undefined {
  let dir = resolve(startDir);
  let visited: string | undefined;

  // dirname() of the root is the root itself, which ends the walk.
  while (dir !== visited) {
    const agentsFile = join(dir, "AGENTS.md");
    if (existsSync(agentsFile)) return agentsFile;
    visited = dir;
    dir = dirname(dir);
  }
  return undefined;
}
/**
 * Builds the AGENTS.md portion of the system-prompt append: the sanitised
 * file content under a `# CLAUDE.md` heading. Returns `undefined` when no
 * file is found, it cannot be read, or nothing is left after trimming.
 */
function extractAgentsAppend(): string | undefined {
  const agentsFile = resolveAgentsMdPath();
  if (!agentsFile) return undefined;

  try {
    const body = readFileSync(agentsFile, "utf-8").trim();
    const cleaned = body ? sanitizeAgentsContent(body) : "";
    return cleaned.length > 0 ? `# CLAUDE.md\n\n${cleaned}` : undefined;
  } catch {
    return undefined;
  }
}
/**
 * Replaces pi-specific wording in AGENTS.md content with Claude equivalents.
 * The rewrites run in order: `~/.pi` paths, `.pi/` paths, remaining `.pi`
 * suffixes, and finally the standalone word "pi" (any case) → "environment".
 */
function sanitizeAgentsContent(content: string): string {
  const rewrites: Array<[RegExp, string]> = [
    [/~\/\.pi\b/gi, "~/.claude"],
    [/(^|[\s'"`])\.pi\//g, "$1.claude/"],
    [/\b\.pi\b/gi, ".claude"],
    [/\bpi\b/gi, "environment"],
  ];
  return rewrites.reduce(
    (text, [pattern, replacement]) => text.replace(pattern, replacement),
    content,
  );
}
// =============================================================================
// Thinking Budget
// =============================================================================
/** Reasoning levels accepted through SimpleStreamOptions["reasoning"]. */
type ThinkingLevel = NonNullable<SimpleStreamOptions["reasoning"]>;
/** Reasoning levels without "xhigh", which is folded into "high" for most models. */
type NonXhighThinkingLevel = Exclude<ThinkingLevel, "xhigh">;
/** Default thinking budgets per level; the chosen value is passed on as maxThinkingTokens. */
const DEFAULT_THINKING_BUDGETS: Record<NonXhighThinkingLevel, number> = {
minimal: 2048,
low: 8192,
medium: 16384,
high: 31999,
};
/** Budgets used when the model id contains "opus-4-6" or "opus-4.6"; includes a dedicated "xhigh" entry. */
const OPUS_46_THINKING_BUDGETS: Record<ThinkingLevel, number> = {
minimal: 2048,
low: 8192,
medium: 31999,
high: 63999,
xhigh: 63999,
};
/**
 * Resolves the thinking-token budget for a reasoning level.
 *
 * - No reasoning level → `undefined`.
 * - Model ids containing "opus-4-6" / "opus-4.6" → the Opus 4.6 table
 *   (custom budgets are not consulted on this path).
 * - Otherwise "xhigh" is treated as "high"; a custom budget is used when it is
 *   a finite positive number, else the default table applies.
 */
function mapThinkingTokens(
  reasoning?: ThinkingLevel,
  modelId?: string,
  thinkingBudgets?: SimpleStreamOptions["thinkingBudgets"],
): number | undefined {
  if (!reasoning) return undefined;

  const opus46Markers = ["opus-4-6", "opus-4.6"];
  if (opus46Markers.some((marker) => modelId?.includes(marker))) {
    return OPUS_46_THINKING_BUDGETS[reasoning];
  }

  const level: NonXhighThinkingLevel = reasoning === "xhigh" ? "high" : reasoning;
  const override = (
    thinkingBudgets as Partial<Record<NonXhighThinkingLevel, number>> | undefined
  )?.[level];

  const overrideUsable =
    typeof override === "number" && Number.isFinite(override) && override > 0;
  return overrideUsable ? override : DEFAULT_THINKING_BUDGETS[level];
}
// =============================================================================
// Misc Helpers
// =============================================================================
/**
 * Parses the tool-argument JSON accumulated so far.
 *
 * @param input JSON text received so far; may be empty or incomplete.
 * @param fallback value returned when `input` cannot be used.
 * @returns the parsed object, or `fallback` when `input` is empty, is not
 *   valid JSON, or parses to something other than a plain object (`null`,
 *   an array or a primitive), since callers treat the result as a
 *   key/value record.
 */
function parsePartialJson(input: string, fallback: Record<string, unknown>): Record<string, unknown> {
  if (!input) return fallback;

  let parsed: unknown;
  try {
    parsed = JSON.parse(input);
  } catch {
    // Incomplete JSON is expected while arguments are still streaming in.
    return fallback;
  }

  if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
    return fallback;
  }
  return parsed as Record<string, unknown>;
}
/**
 * Maps an Anthropic stop reason onto pi's stop reasons: "tool_use" → "toolUse",
 * "max_tokens" → "length", and anything else (including `undefined`) → "stop".
 */
function mapStopReason(reason: string | undefined): "stop" | "length" | "toolUse" {
  if (reason === "tool_use") return "toolUse";
  if (reason === "max_tokens") return "length";
  return "stop";
}
// =============================================================================
// Main Stream Function
// =============================================================================
/**
 * Streams one assistant turn through the Claude Agent SDK.
 *
 * The returned event stream is filled asynchronously:
 *  1. Tools, system-prompt append and provider settings are resolved.
 *  2. Either an existing Claude Code session is resumed (only the new pi
 *     messages are sent) or a fresh session is started with the whole
 *     conversation encoded as text.
 *  3. SDK stream events are translated into pi events (text, thinking and
 *     tool-call blocks plus usage). The SDK never executes tools itself:
 *     as soon as a message containing a tool call is complete the query is
 *     stopped and the turn ends with stop reason "toolUse".
 *
 * A session is resumed only if it was created with the same model and the
 * context holds more messages than were already sent. If the query fails, the
 * registry entry is dropped so the next call starts a fresh session with the
 * full history instead of retrying the same resume.
 *
 * @param model pi model descriptor; used for output metadata, cost and the thinking budget.
 * @param context pi conversation context (messages, tools, system prompt).
 * @param options optional stream options (abort signal, reasoning level, …).
 * @returns the assistant message event stream; it always ends with a "done" or "error" event.
 */
function streamClaudeAgentSdk(
  model: Model<any>,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream {
  const stream = createAssistantMessageEventStream();
  (async () => {
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: model.api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    let sdkQuery: ReturnType<typeof query> | undefined;
    let wasAborted = false;
    // Set once this call has written to the session registry, so the error
    // path knows which entry to invalidate.
    let registeredKey: string | undefined;
    // Ask the running query to stop; fall back to close() if interrupt fails.
    const requestAbort = () => {
      if (!sdkQuery) return;
      void sdkQuery.interrupt().catch(() => {
        try {
          sdkQuery?.close();
        } catch {
          // ignore
        }
      });
    };
    const onAbort = () => {
      wasAborted = true;
      requestAbort();
    };
    if (options?.signal) {
      if (options.signal.aborted) onAbort();
      else options.signal.addEventListener("abort", onAbort, { once: true });
    }
    // Same array as output.content, typed with the temporary streaming fields
    // (`index`, `partialJson`) that are removed when a block is finished.
    const blocks = output.content as Array<
      | { type: "text"; text: string; index: number }
      | { type: "thinking"; thinking: string; thinkingSignature?: string; index: number }
      | {
          type: "toolCall";
          id: string;
          name: string;
          arguments: Record<string, unknown>;
          partialJson: string;
          index: number;
        }
    >;
    let started = false;
    let sawStreamEvent = false;
    let sawToolCall = false;
    let shouldStopEarly = false;
    try {
      const { sdkTools, customTools, customToolNameToSdk, customToolNameToPi } =
        resolveSdkTools(context);
      const cwd = (options as { cwd?: string } | undefined)?.cwd ?? process.cwd();
      const mcpServers = buildCustomToolServers(customTools);
      const providerSettings = loadProviderSettings();
      const appendSystemPrompt = providerSettings.appendSystemPrompt !== false;
      const agentsAppend = appendSystemPrompt ? extractAgentsAppend() : undefined;
      const skillsAppend = appendSystemPrompt
        ? extractSkillsAppend(context.systemPrompt)
        : undefined;
      const appendParts = [agentsAppend, skillsAppend].filter((p): p is string => Boolean(p));
      const systemPromptAppend = appendParts.length > 0 ? appendParts.join("\n\n") : undefined;
      // Paths are mapped back to pi's skill directories only when the skills
      // block (with aliased locations) was actually shown to the model.
      const allowSkillAliasRewrite = Boolean(skillsAppend);
      const settingSources: SettingSource[] | undefined = appendSystemPrompt
        ? undefined
        : providerSettings.settingSources ?? ["user", "project"];
      const strictMcpConfigEnabled = !appendSystemPrompt && providerSettings.strictMcpConfig !== false;
      const generatedExtraArgs = strictMcpConfigEnabled ? { "strict-mcp-config": null } : undefined;
      const extraArgs = { ...generatedExtraArgs, ...providerSettings.extraArgs };
      // ----------------------------------------------------------------
      // Decide: resume existing session or start fresh?
      // ----------------------------------------------------------------
      const convKey = getConversationKey(context, options);
      const existingSession = sessionRegistry.get(convKey);
      let prompt: AsyncIterable<SDKUserMessage>;
      let claudeCodeSessionId: string;
      let newSentMsgCount: number;
      let isResume: boolean;
      if (
        existingSession &&
        existingSession.modelId === model.id &&
        context.messages.length > existingSession.sentMsgCount
      ) {
        // Continuation: resume and inject new tool results / user messages.
        claudeCodeSessionId = existingSession.claudeCodeSessionId;
        newSentMsgCount = context.messages.length;
        isResume = true;
        prompt = buildResumeStream(context, existingSession.sentMsgCount, claudeCodeSessionId);
      } else {
        // Fresh session: first turn, model changed, no new messages, or the
        // previous session was cleaned up or invalidated.
        claudeCodeSessionId = randomUUID();
        newSentMsgCount = context.messages.length;
        isResume = false;
        const fallbackBlocks = buildFallbackTextBlocks(context, customToolNameToSdk);
        prompt = buildFallbackStream(fallbackBlocks);
      }
      // NOTE(review): `model.id` is not forwarded to the SDK in these options —
      // confirm that relying on the CLI's own model selection is intended.
      const queryOptions: NonNullable<Parameters<typeof query>[0]["options"]> = {
        cwd,
        tools: sdkTools,
        permissionMode: "dontAsk",
        includePartialMessages: true,
        // Tool execution is always refused here.
        canUseTool: async () => ({
          behavior: "deny",
          message: TOOL_EXECUTION_DENIED_MESSAGE,
        }),
        systemPrompt: {
          type: "preset",
          preset: "claude_code",
          append: systemPromptAppend ?? undefined,
        },
        ...(settingSources ? { settingSources } : {}),
        ...(Object.keys(extraArgs).length > 0 ? { extraArgs } : {}),
        ...(mcpServers ? { mcpServers } : {}),
        // Session management
        ...(isResume
          ? { resume: claudeCodeSessionId }
          : { sessionId: claudeCodeSessionId, persistSession: true }),
      };
      const maxThinkingTokens = mapThinkingTokens(options?.reasoning, model.id, options?.thinkingBudgets);
      if (maxThinkingTokens != null) {
        queryOptions.maxThinkingTokens = maxThinkingTokens;
      }
      // Register / update session state BEFORE the query so it is available
      // even when the loop below stops early after a tool call.
      sessionRegistry.set(convKey, {
        claudeCodeSessionId,
        sentMsgCount: newSentMsgCount,
        modelId: model.id,
      });
      registeredKey = convKey;
      sdkQuery = query({ prompt, options: queryOptions });
      // The abort may have arrived before the query object existed.
      if (wasAborted) requestAbort();
      for await (const message of sdkQuery) {
        if (!started) {
          stream.push({ type: "start", partial: output });
          started = true;
        }
        switch (message.type) {
          case "stream_event": {
            sawStreamEvent = true;
            const event = (message as SDKMessage & { event: any }).event;
            if (event?.type === "message_start") {
              const usage = event.message?.usage;
              output.usage.input = usage?.input_tokens ?? 0;
              output.usage.output = usage?.output_tokens ?? 0;
              output.usage.cacheRead = usage?.cache_read_input_tokens ?? 0;
              output.usage.cacheWrite = usage?.cache_creation_input_tokens ?? 0;
              output.usage.totalTokens =
                output.usage.input +
                output.usage.output +
                output.usage.cacheRead +
                output.usage.cacheWrite;
              calculateCost(model, output.usage);
              break;
            }
            if (event?.type === "content_block_start") {
              if (event.content_block?.type === "text") {
                blocks.push({ type: "text", text: "", index: event.index });
                stream.push({
                  type: "text_start",
                  contentIndex: output.content.length - 1,
                  partial: output,
                });
              } else if (event.content_block?.type === "thinking") {
                blocks.push({
                  type: "thinking",
                  thinking: "",
                  thinkingSignature: "",
                  index: event.index,
                });
                stream.push({
                  type: "thinking_start",
                  contentIndex: output.content.length - 1,
                  partial: output,
                });
              } else if (event.content_block?.type === "tool_use") {
                sawToolCall = true;
                blocks.push({
                  type: "toolCall",
                  id: event.content_block.id,
                  name: mapToolName(event.content_block.name, customToolNameToPi),
                  arguments: (event.content_block.input as Record<string, unknown>) ?? {},
                  partialJson: "",
                  index: event.index,
                });
                stream.push({
                  type: "toolcall_start",
                  contentIndex: output.content.length - 1,
                  partial: output,
                });
              }
              break;
            }
            if (event?.type === "content_block_delta") {
              if (event.delta?.type === "text_delta") {
                const idx = blocks.findIndex((b) => b.index === event.index);
                const block = blocks[idx];
                if (block?.type === "text") {
                  block.text += event.delta.text;
                  stream.push({
                    type: "text_delta",
                    contentIndex: idx,
                    delta: event.delta.text,
                    partial: output,
                  });
                }
              } else if (event.delta?.type === "thinking_delta") {
                const idx = blocks.findIndex((b) => b.index === event.index);
                const block = blocks[idx];
                if (block?.type === "thinking") {
                  block.thinking += event.delta.thinking;
                  stream.push({
                    type: "thinking_delta",
                    contentIndex: idx,
                    delta: event.delta.thinking,
                    partial: output,
                  });
                }
              } else if (event.delta?.type === "input_json_delta") {
                const idx = blocks.findIndex((b) => b.index === event.index);
                const block = blocks[idx];
                if (block?.type === "toolCall") {
                  block.partialJson += event.delta.partial_json;
                  // Keeps the previous arguments until the JSON parses.
                  block.arguments = parsePartialJson(block.partialJson, block.arguments);
                  stream.push({
                    type: "toolcall_delta",
                    contentIndex: idx,
                    delta: event.delta.partial_json,
                    partial: output,
                  });
                }
              } else if (event.delta?.type === "signature_delta") {
                const idx = blocks.findIndex((b) => b.index === event.index);
                const block = blocks[idx];
                if (block?.type === "thinking") {
                  block.thinkingSignature =
                    (block.thinkingSignature ?? "") + event.delta.signature;
                }
              }
              break;
            }
            if (event?.type === "content_block_stop") {
              const idx = blocks.findIndex((b) => b.index === event.index);
              const block = blocks[idx];
              if (!block) break;
              // Drop the streaming index so a later block reusing the same
              // index cannot match this finished block again.
              delete (block as any).index;
              if (block.type === "text") {
                stream.push({
                  type: "text_end",
                  contentIndex: idx,
                  content: block.text,
                  partial: output,
                });
              } else if (block.type === "thinking") {
                stream.push({
                  type: "thinking_end",
                  contentIndex: idx,
                  content: block.thinking,
                  partial: output,
                });
              } else if (block.type === "toolCall") {
                sawToolCall = true;
                // Convert the SDK argument shape into pi's only once the
                // arguments are complete.
                block.arguments = mapToolArgs(
                  block.name,
                  parsePartialJson(block.partialJson, block.arguments),
                  allowSkillAliasRewrite,
                );
                delete (block as any).partialJson;
                stream.push({
                  type: "toolcall_end",
                  contentIndex: idx,
                  toolCall: block,
                  partial: output,
                });
              }
              break;
            }
            if (event?.type === "message_delta") {
              output.stopReason = mapStopReason(event.delta?.stop_reason);
              const usage = event.usage ?? {};
              if (usage.input_tokens != null) output.usage.input = usage.input_tokens;
              if (usage.output_tokens != null) output.usage.output = usage.output_tokens;
              if (usage.cache_read_input_tokens != null)
                output.usage.cacheRead = usage.cache_read_input_tokens;
              if (usage.cache_creation_input_tokens != null)
                output.usage.cacheWrite = usage.cache_creation_input_tokens;
              output.usage.totalTokens =
                output.usage.input +
                output.usage.output +
                output.usage.cacheRead +
                output.usage.cacheWrite;
              calculateCost(model, output.usage);
              break;
            }
            if (event?.type === "message_stop" && sawToolCall) {
              // pi executes the tool; stop consuming the query here.
              output.stopReason = "toolUse";
              shouldStopEarly = true;
              break;
            }
            break;
          }
          case "result": {
            // Without any stream events, fall back to the final result text.
            if (!sawStreamEvent && message.subtype === "success") {
              output.content.push({ type: "text", text: (message as any).result || "" });
            }
            break;
          }
        }
        if (shouldStopEarly) break;
      }
      if (wasAborted || options?.signal?.aborted) {
        output.stopReason = "aborted";
        output.errorMessage = "Operation aborted";
        stream.push({ type: "error", reason: "aborted", error: output });
        stream.end();
        return;
      }
      // Clean up session registry when the conversation ends naturally
      // (no tool calls → final answer → conversation over).
      if (output.stopReason === "stop" || output.stopReason === "length") {
        sessionRegistry.delete(convKey);
      }
      stream.push({
        type: "done",
        reason:
          output.stopReason === "toolUse"
            ? "toolUse"
            : output.stopReason === "length"
              ? "length"
              : "stop",
        message: output,
      });
      stream.end();
    } catch (error) {
      // The session may be missing or in an unknown state after a failure.
      // Forget it so the next call starts fresh with the full history rather
      // than attempting the same resume again.
      if (registeredKey !== undefined) {
        sessionRegistry.delete(registeredKey);
      }
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = error instanceof Error ? error.message : String(error);
      stream.push({
        type: "error",
        reason: output.stopReason as "aborted" | "error",
        error: output,
      });
      stream.end();
    } finally {
      if (options?.signal) {
        options.signal.removeEventListener("abort", onAbort);
      }
      try {
        sdkQuery?.close();
      } catch {
        // Cleanup only: the stream has already ended, so a failing close()
        // must not surface as an unhandled rejection.
      }
    }
  })();
  return stream;
}
// =============================================================================
// Extension Entry Point
// =============================================================================
/**
 * Extension entry point: registers the Claude Agent SDK provider with pi,
 * exposing the Anthropic model list and routing streaming through
 * streamClaudeAgentSdk.
 */
export default function (pi: ExtensionAPI) {
pi.registerProvider(PROVIDER_ID, {
baseUrl: "claude-agent-sdk",
// NOTE(review): presumably the name of the environment variable holding the key, not a key itself — confirm against the ExtensionAPI docs.
apiKey: "ANTHROPIC_API_KEY",
api: "claude-agent-sdk",
models: MODELS,
streamSimple: streamClaudeAgentSdk,
});
}