1187 lines
37 KiB
TypeScript
1187 lines
37 KiB
TypeScript
/**
|
|
* Claude Agent SDK Provider Extension
|
|
*
|
|
* Routes pi requests through the Claude Code CLI via @anthropic-ai/claude-agent-sdk.
|
|
*
|
|
* ARCHITECTURE (improved over npm:claude-agent-sdk-pi):
|
|
*
|
|
* The original extension flattened the entire conversation into a single user
|
|
* message on every call, labelling previous tool calls as "non-executable".
|
|
* This caused the model to believe its prior tool calls had never run, resulting
|
|
* in infinite retry loops.
|
|
*
|
|
* This version uses proper session persistence + resume:
|
|
* - First turn: starts a Claude Code session (persistSession=true, custom UUID)
|
|
* - Continuation turns: resumes that session and injects tool results as proper
|
|
* SDKUserMessage items with parent_tool_use_id set — giving Claude Code the
|
|
* native multi-turn context it needs.
|
|
*
|
|
* Fallback: if a session can't be resumed (e.g. the file wasn't flushed before
|
|
* we closed the query), the full conversation is re-sent as a structured text
|
|
* block with clear labels so the model understands what was already done.
|
|
*
|
|
* BUGS FIXED:
|
|
* 1. "Historical tool call (non-executable)" → model thought calls never ran.
|
|
* Fixed label + session-resume approach eliminates the confusion entirely.
|
|
* 2. Edit tool args: SDK sends {old_string, new_string} but pi's Edit tool
|
|
* requires {edits: [{oldText, newText}]}. Now properly wrapped.
|
|
*/
|
|
|
|
import {
|
|
calculateCost,
|
|
createAssistantMessageEventStream,
|
|
getModels,
|
|
type AssistantMessage,
|
|
type AssistantMessageEventStream,
|
|
type Context,
|
|
type ImageContent,
|
|
type Model,
|
|
type SimpleStreamOptions,
|
|
type TextContent,
|
|
type Tool,
|
|
type ToolResultMessage,
|
|
type UserMessage,
|
|
} from "@mariozechner/pi-ai";
|
|
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
|
|
import {
|
|
createSdkMcpServer,
|
|
query,
|
|
type SDKMessage,
|
|
type SDKUserMessage,
|
|
type SettingSource,
|
|
} from "@anthropic-ai/claude-agent-sdk";
|
|
import type {
|
|
Base64ImageSource,
|
|
ContentBlockParam,
|
|
MessageParam,
|
|
} from "@anthropic-ai/sdk/resources";
|
|
import { pascalCase } from "change-case";
|
|
import { existsSync, readFileSync } from "fs";
|
|
import { randomUUID } from "crypto";
|
|
import { homedir } from "os";
|
|
import { dirname, join, relative, resolve } from "path";
|
|
|
|
// =============================================================================
|
|
// Constants
|
|
// =============================================================================
|
|
|
|
/** Provider id under which this extension registers itself. */
const PROVIDER_ID = "claude-agent-sdk";

// --- Tool name tables (keys are always lowercase) ---------------------------

/** Claude Code built-in tool name → pi tool name. */
const SDK_TO_PI_TOOL_NAME: Record<string, string> = {
  read: "read",
  write: "write",
  edit: "edit",
  bash: "bash",
  grep: "grep",
  // Claude Code's Glob tool is pi's find tool.
  glob: "find",
};

/** pi tool name → Claude Code built-in tool name. */
const PI_TO_SDK_TOOL_NAME: Record<string, string> = {
  read: "Read",
  write: "Write",
  edit: "Edit",
  bash: "Bash",
  grep: "Grep",
  find: "Glob",
  glob: "Glob",
};

/** Built-in tools enabled when the pi context does not list any tools. */
const DEFAULT_TOOLS = [
  "Read",
  "Write",
  "Edit",
  "Bash",
  "Grep",
  "Glob",
];

/** Lowercased pi tool names that correspond to a Claude Code built-in. */
const BUILTIN_TOOL_NAMES = new Set(Object.keys(PI_TO_SDK_TOOL_NAME));

/** Text returned whenever Claude Code tries to execute a tool itself. */
const TOOL_EXECUTION_DENIED_MESSAGE = "Tool execution is unavailable in this environment.";

// --- In-process MCP server for pi's custom tools ----------------------------

const MCP_SERVER_NAME = "custom-tools";
/** Prefix of the SDK-side names given to custom tools. */
const MCP_TOOL_PREFIX = `mcp__${MCP_SERVER_NAME}__`;

// --- Skills / settings / AGENTS.md locations --------------------------------

/** Skill directory aliases as they are shown to Claude Code. */
const SKILLS_ALIAS_GLOBAL = "~/.claude/skills";
const SKILLS_ALIAS_PROJECT = ".claude/skills";

/** Global (per-user) pi locations. */
const GLOBAL_SKILLS_ROOT = join(homedir(), ".pi", "agent", "skills");
const GLOBAL_SETTINGS_PATH = join(homedir(), ".pi", "agent", "settings.json");
const GLOBAL_AGENTS_PATH = join(homedir(), ".pi", "agent", "AGENTS.md");

/** Project pi locations, resolved against the cwd at module load time. */
const PROJECT_SKILLS_ROOT = join(process.cwd(), ".pi", "skills");
const PROJECT_SETTINGS_PATH = join(process.cwd(), ".pi", "settings.json");
|
|
|
|
// =============================================================================
|
|
// Session Registry
|
|
// Tracks active Claude Code sessions keyed by pi conversation identity.
|
|
// Enables proper session resumption instead of re-encoding history as text.
|
|
// =============================================================================
|
|
|
|
/**
 * Bookkeeping for one Claude Code session started on behalf of a pi
 * conversation, so later turns can resume it instead of replaying history.
 */
interface SessionState {
  /** UUID of the Claude Code session; passed back as `resume` on later turns. */
  claudeCodeSessionId: string;

  /**
   * Count of leading `context.messages` entries already delivered to Claude
   * Code. The following turn forwards `context.messages.slice(sentMsgCount)`.
   */
  sentMsgCount: number;

  /**
   * Model id recorded when the entry was written.
   * NOTE(review): intended for resume validation — confirm it is actually
   * compared before resuming.
   */
  modelId: string;
}

/**
 * Live sessions keyed by getConversationKey().
 * NOTE(review): no eviction is visible in this section — confirm entries are
 * cleaned up, otherwise the map grows for the lifetime of the process.
 */
const sessionRegistry: Map<string, SessionState> = new Map();
|
|
|
|
/**
 * Derives a stable registry key for the current pi conversation.
 *
 * Uses `options.sessionId` when pi provides one. Otherwise the key
 * fingerprints the first user message with three parts:
 *   - its timestamp, when the message carries a numeric one;
 *   - the full length of its serialized content;
 *   - the first 200 characters of that content.
 *
 * The timestamp and length were added because a 200-char prefix alone lets
 * two unrelated conversations that open with the same words (e.g.
 * "continue") share a key. They would then resume the wrong Claude Code
 * session.
 *
 * @param context - the pi conversation; only `messages` is read
 * @param options - stream options; `sessionId` is read when present
 * @returns a key of the form `sid:…`, `fp:…` or `empty:…`
 */
function getConversationKey(context: Context, options?: SimpleStreamOptions): string {
  const sid = (options as { sessionId?: string } | undefined)?.sessionId;
  if (sid) return `sid:${sid}`;

  const firstUser = context.messages.find((m) => m.role === "user") as
    | (UserMessage & { timestamp?: unknown })
    | undefined;
  if (!firstUser) return `empty:${Date.now()}`;

  // JSON.stringify(undefined) yields undefined; fall back to "" so a malformed
  // message cannot throw here.
  const content =
    typeof firstUser.content === "string"
      ? firstUser.content
      : (JSON.stringify(firstUser.content) ?? "");

  // The first message's timestamp is fixed once the message exists, so it stays
  // stable across turns while distinguishing conversations with identical openers.
  const timestamp = typeof firstUser.timestamp === "number" ? String(firstUser.timestamp) : "";

  return `fp:${timestamp}:${content.length}:${content.slice(0, 200)}`;
}
|
|
|
|
// =============================================================================
|
|
// Models
|
|
// =============================================================================
|
|
|
|
/**
 * Anthropic model catalogue from pi-ai, reduced to the metadata fields this
 * provider exposes (id, name, reasoning, input, cost, context/max tokens).
 */
const MODELS = getModels("anthropic").map(
  ({ id, name, reasoning, input, cost, contextWindow, maxTokens }) => ({
    id,
    name,
    reasoning,
    input,
    cost,
    contextWindow,
    maxTokens,
  }),
);
|
|
|
|
// =============================================================================
|
|
// Tool Name Mapping
|
|
// =============================================================================
|
|
|
|
/**
 * Translates a pi tool name into the name Claude Code knows it by.
 *
 * Lookup order:
 *   1. the custom-tool map (exact name, then lowercased);
 *   2. the built-in table;
 *   3. a PascalCase rendering of the original name.
 *
 * A missing or empty name yields "".
 */
function mapPiToolNameToSdk(name?: string, customToolNameToSdk?: Map<string, string>): string {
  if (!name) {
    return "";
  }
  const lower = name.toLowerCase();

  const custom = customToolNameToSdk?.get(name) ?? customToolNameToSdk?.get(lower);
  if (custom) {
    return custom;
  }

  return PI_TO_SDK_TOOL_NAME[lower] || pascalCase(name);
}
|
|
|
|
/**
 * Translates a tool name reported by Claude Code back to the pi tool name.
 *
 * Lookup order:
 *   1. built-ins;
 *   2. the custom-tool reverse map (exact, then lowercased);
 *   3. names carrying the custom MCP server prefix, with the prefix removed.
 *
 * Anything else is returned unchanged.
 */
function mapToolName(name: string, customToolNameToPi?: Map<string, string>): string {
  const lower = name.toLowerCase();

  const fromBuiltin = SDK_TO_PI_TOOL_NAME[lower];
  if (fromBuiltin) return fromBuiltin;

  const fromCustom = customToolNameToPi?.get(name) ?? customToolNameToPi?.get(lower);
  if (fromCustom) return fromCustom;

  return lower.startsWith(MCP_TOOL_PREFIX) ? name.slice(MCP_TOOL_PREFIX.length) : name;
}
|
|
|
|
// =============================================================================
|
|
// Tool Argument Mapping
|
|
// BUG FIX: Edit previously mapped old_string→oldText as a top-level key,
|
|
// but pi's Edit tool requires edits:[{oldText,newText}]. Fixed here.
|
|
// =============================================================================
|
|
|
|
/**
 * Maps a path written against the Claude-style skills aliases back onto the
 * real pi skills directories, so pi's tools can open it.
 *
 * Recognised prefixes, tried in this order. Each must be followed by a path
 * separator or the end of the string:
 *   - `~/.claude/skills`      → `~/.pi/agent/skills`
 *   - `./.claude/skills`      → PROJECT_SKILLS_ROOT
 *   - `.claude/skills`        → PROJECT_SKILLS_ROOT
 *   - `<cwd>/.claude/skills`  → PROJECT_SKILLS_ROOT
 *   - `<home>/.claude/skills` → GLOBAL_SKILLS_ROOT
 *
 * Fixes over the previous version:
 *   - Prefixes now match whole path segments. An unrelated sibling such as
 *     `.claude/skills-old/x` used to be rewritten into
 *     `<PROJECT_SKILLS_ROOT>-old/x`.
 *   - The absolute form of the global alias is now recognised. This is the
 *     form produced when `~` has already been expanded, e.g. by the model
 *     for an absolute-path file tool.
 *
 * @param pathValue - a tool argument; non-string values are returned as-is
 * @returns the rewritten path, or `pathValue` when no alias prefix matches
 */
function rewriteSkillAliasPath(pathValue: unknown): unknown {
  if (typeof pathValue !== "string") return pathValue;
  const value = pathValue;

  // The cwd rule precedes the home rule so that, when cwd === home, the
  // project mapping still wins, as it did before.
  const rules: Array<[string, string]> = [
    [SKILLS_ALIAS_GLOBAL, "~/.pi/agent/skills"],
    [`./${SKILLS_ALIAS_PROJECT}`, PROJECT_SKILLS_ROOT],
    [SKILLS_ALIAS_PROJECT, PROJECT_SKILLS_ROOT],
    [join(process.cwd(), SKILLS_ALIAS_PROJECT), PROJECT_SKILLS_ROOT],
    [join(homedir(), SKILLS_ALIAS_PROJECT), GLOBAL_SKILLS_ROOT],
  ];

  for (const [prefix, replacement] of rules) {
    if (!value.startsWith(prefix)) continue;
    const rest = value.slice(prefix.length);
    // Require a segment boundary: "skills" or "skills/…", never "skills-old".
    if (rest === "" || rest.startsWith("/") || rest.startsWith("\\")) {
      return replacement + rest;
    }
  }
  return value;
}
|
|
|
|
/**
 * Converts Claude Code tool arguments into the argument shape that pi's tool
 * of the same (lowercased) name expects.
 *
 * Both Claude Code spellings (`file_path`, `old_string`, …) and pi spellings
 * (`path`, `oldText`, …) are accepted where they differ.
 *
 * Fixes over the previous version:
 *   - In the edits-array form, each entry is normalised too. Claude Code
 *     spells per-edit keys `old_string`/`new_string`, and those entries used
 *     to reach pi unconverted.
 *   - "glob" is treated like "find".
 *
 * @param toolName - tool name, matched case-insensitively
 * @param args - raw arguments (missing → `{}`)
 * @param allowSkillAliasRewrite - when true, path arguments are passed
 *   through rewriteSkillAliasPath
 * @returns pi-shaped arguments; unknown tools get their input passed through
 */
function mapToolArgs(
  toolName: string,
  args: Record<string, unknown> | undefined,
  allowSkillAliasRewrite = true,
): Record<string, unknown> {
  const normalized = toolName.toLowerCase();
  const input = args ?? {};
  const resolvePath = (value: unknown) => (allowSkillAliasRewrite ? rewriteSkillAliasPath(value) : value);

  // Reads one edit in any supported spelling. Returns undefined when it has
  // neither an old nor a new text, so the caller can treat it as "no edit here".
  const toPiEdit = (source: Record<string, unknown>): { oldText: string; newText: string } | undefined => {
    const oldText = source.old_string ?? source.oldText ?? source.old_text;
    const newText = source.new_string ?? source.newText ?? source.new_text;
    if (oldText === undefined && newText === undefined) return undefined;
    return { oldText: String(oldText ?? ""), newText: String(newText ?? "") };
  };

  switch (normalized) {
    case "read":
      return {
        path: resolvePath(input.file_path ?? input.path),
        offset: input.offset,
        limit: input.limit,
      };

    case "write":
      return {
        path: resolvePath(input.file_path ?? input.path),
        content: input.content,
      };

    case "edit": {
      const path = resolvePath(input.file_path ?? input.path);

      // BUG FIX: SDK sends {old_string, new_string} at the top level,
      // but pi's Edit tool expects {edits: [{oldText, newText}]}.
      const single = toPiEdit(input);
      if (single) return { path, edits: [single] };

      // Already in edits-array form (e.g. multi-edit calls). Entries in a
      // recognised spelling are normalised; anything else passes through.
      const edits = Array.isArray(input.edits)
        ? input.edits.map((edit: unknown) =>
            edit !== null && typeof edit === "object"
              ? (toPiEdit(edit as Record<string, unknown>) ?? edit)
              : edit,
          )
        : (input.edits ?? []);
      return { path, edits };
    }

    case "bash":
      // NOTE(review): Claude Code documents Bash `timeout` in milliseconds;
      // confirm pi's bash tool uses the same unit, or convert here.
      return {
        command: input.command,
        timeout: input.timeout,
      };

    case "grep":
      return {
        pattern: input.pattern,
        path: resolvePath(input.path),
        glob: input.glob,
        limit: input.head_limit ?? input.limit,
      };

    case "glob":
    case "find":
      return {
        pattern: input.pattern,
        path: resolvePath(input.path),
      };

    default:
      return input;
  }
}
|
|
|
|
// =============================================================================
|
|
// SDK Tool Resolution
|
|
// =============================================================================
|
|
|
|
/**
 * Splits pi's tool list into Claude Code built-ins and custom tools.
 *
 * Built-ins (matched case-insensitively via BUILTIN_TOOL_NAMES) become SDK
 * tool names. Every other tool is exposed as `${MCP_TOOL_PREFIX}<name>`, and
 * both directions are recorded under the exact and lowercased names.
 * Without `context.tools`, the DEFAULT_TOOLS set is returned (as a copy).
 */
function resolveSdkTools(context: Context): {
  sdkTools: string[];
  customTools: Tool[];
  customToolNameToSdk: Map<string, string>;
  customToolNameToPi: Map<string, string>;
} {
  const customToolNameToSdk = new Map<string, string>();
  const customToolNameToPi = new Map<string, string>();

  if (!context.tools) {
    return { sdkTools: [...DEFAULT_TOOLS], customTools: [], customToolNameToSdk, customToolNameToPi };
  }

  const builtinSdkNames = new Set<string>();
  const customTools = context.tools.filter((tool) => {
    const lower = tool.name.toLowerCase();
    if (BUILTIN_TOOL_NAMES.has(lower)) {
      const builtin = PI_TO_SDK_TOOL_NAME[lower];
      if (builtin) builtinSdkNames.add(builtin);
      return false;
    }
    const mcpName = `${MCP_TOOL_PREFIX}${tool.name}`;
    customToolNameToSdk.set(tool.name, mcpName).set(lower, mcpName);
    customToolNameToPi.set(mcpName, tool.name).set(mcpName.toLowerCase(), tool.name);
    return true;
  });

  return { sdkTools: [...builtinSdkNames], customTools, customToolNameToSdk, customToolNameToPi };
}
|
|
|
|
/**
 * Exposes pi's custom tools to Claude Code through one in-process MCP server
 * named MCP_SERVER_NAME. Returns undefined when there are no custom tools.
 *
 * The handlers never execute anything: each call answers with
 * TOOL_EXECUTION_DENIED_MESSAGE as an error result.
 *
 * NOTE(review): each pi tool's `parameters` is passed as `inputSchema`
 * through an `unknown` cast — confirm this is the schema form
 * createSdkMcpServer expects.
 */
function buildCustomToolServers(
  customTools: Tool[],
): Record<string, ReturnType<typeof createSdkMcpServer>> | undefined {
  if (customTools.length === 0) {
    return undefined;
  }

  const denyExecution = async () => ({
    content: [{ type: "text", text: TOOL_EXECUTION_DENIED_MESSAGE }],
    isError: true,
  });

  return {
    [MCP_SERVER_NAME]: createSdkMcpServer({
      name: MCP_SERVER_NAME,
      version: "1.0.0",
      tools: customTools.map(({ name, description, parameters }) => ({
        name,
        description,
        inputSchema: parameters as unknown,
        handler: denyExecution,
      })),
    }),
  };
}
|
|
|
|
// =============================================================================
|
|
// Prompt Stream Builders
|
|
// =============================================================================
|
|
|
|
/**
 * Maps pi user-message content onto Anthropic content blocks.
 *
 * A bare string becomes one text block. In an array, text parts map to text
 * blocks and every other part is treated as a base64 image.
 */
function userContentToBlocks(
  content: string | (TextContent | ImageContent)[],
): ContentBlockParam[] {
  if (typeof content !== "string") {
    return content.map(
      (part): ContentBlockParam =>
        part.type === "text"
          ? { type: "text", text: part.text }
          : {
              type: "image",
              source: {
                type: "base64",
                media_type: (part as ImageContent).mimeType as Base64ImageSource["media_type"],
                data: (part as ImageContent).data,
              },
            },
    );
  }
  return [{ type: "text", text: content }];
}
|
|
|
|
/**
 * Flattens pi tool-result content into one string. Text parts are joined
 * with "\n"; non-text (image) parts are left out.
 */
function toolResultContentToText(content: (TextContent | ImageContent)[]): string {
  const lines: string[] = [];
  for (const part of content) {
    if (part.type === "text") {
      lines.push((part as TextContent).text);
    }
  }
  return lines.join("\n");
}
|
|
|
|
/**
 * BUILD STRATEGY A — Session Resume
 *
 * Yields only the pi messages added since `sentMsgCount`, as SDKUserMessages:
 *   - user         → SDKUserMessage with parent_tool_use_id: null
 *   - toolResult   → SDKUserMessage holding one `tool_result` block, with
 *                    parent_tool_use_id set to the tool call id
 *   - assistant    → skipped (Claude Code produced these; the session has them)
 *
 * Tool results now keep their images. Text-only results are still sent as a
 * plain string. Results that contain image parts are sent as text + image
 * blocks; previously they were flattened to text and the images were dropped.
 *
 * NOTE(review): each tool result is yielded as its own SDK message; confirm
 * results of parallel tool calls are accepted when delivered separately.
 *
 * @param context - the pi context; only `messages` is read
 * @param sentMsgCount - number of leading messages already delivered
 * @param claudeCodeSessionId - session id stamped on every yielded message
 */
function buildResumeStream(
  context: Context,
  sentMsgCount: number,
  claudeCodeSessionId: string,
): AsyncIterable<SDKUserMessage> {
  const newMessages = context.messages.slice(sentMsgCount);

  type ToolResultPart = Extract<ContentBlockParam, { type: "text" | "image" }>;

  /** Builds the `content` of a tool_result block for one pi tool result. */
  const toToolResultContent = (result: ToolResultMessage): string | ToolResultPart[] => {
    const parts = Array.isArray(result.content) ? result.content : [];
    const text = toolResultContentToText(parts);
    const images = parts.filter((part): part is ImageContent => part.type === "image");
    if (images.length === 0) return text;

    const blocks: ToolResultPart[] = [];
    if (text.length > 0) blocks.push({ type: "text", text });
    for (const image of images) {
      blocks.push({
        type: "image",
        source: {
          type: "base64",
          media_type: image.mimeType as Base64ImageSource["media_type"],
          data: image.data,
        },
      });
    }
    return blocks;
  };

  async function* generator() {
    for (const msg of newMessages) {
      if (msg.role === "user") {
        const umsg = msg as UserMessage;
        yield {
          type: "user" as const,
          message: {
            role: "user" as const,
            content: userContentToBlocks(umsg.content),
          } as MessageParam,
          parent_tool_use_id: null,
          session_id: claudeCodeSessionId,
        } satisfies SDKUserMessage;
      } else if (msg.role === "toolResult") {
        const result = msg as ToolResultMessage;
        yield {
          type: "user" as const,
          message: {
            role: "user" as const,
            content: [
              {
                type: "tool_result" as const,
                tool_use_id: result.toolCallId,
                content: toToolResultContent(result),
                ...(result.isError ? { is_error: true } : {}),
              },
            ],
          } as MessageParam,
          parent_tool_use_id: result.toolCallId,
          session_id: claudeCodeSessionId,
        } satisfies SDKUserMessage;
      }
      // Assistant messages fall through: the resumed session already has them.
    }
  }

  return generator();
}
|
|
|
|
/**
 * BUILD STRATEGY B — Text Fallback (fresh session or resume failure)
 *
 * Encodes the full conversation history as labelled text (plus any images)
 * inside a single user message for a fresh Claude Code session.
 *
 * BUG FIX: previously used "Historical tool call (non-executable): …" which
 * caused the model to believe tool calls had never executed. Now uses
 * "[TOOL CALL - EXECUTED]" with clear success indicators.
 *
 * Further fixes:
 *   - Empty or whitespace-only text is never emitted as its own block, since
 *     the Messages API rejects such blocks.
 *   - "(see attached image)" is only added when an image really was attached.
 *     A user message with neither text nor images is marked
 *     "(empty message)"; a tool result with neither is marked "(no output)".
 *
 * @param context - the pi context; `messages` is encoded in order
 * @param customToolNameToSdk - custom tool name map used to render tool names
 * @returns prompt content blocks; never empty
 */
function buildFallbackTextBlocks(
  context: Context,
  customToolNameToSdk: Map<string, string> | undefined,
): ContentBlockParam[] {
  const blocks: ContentBlockParam[] = [];

  // Drop blank text instead of emitting an invalid empty text block.
  const pushText = (text: string) => {
    if (text.trim().length > 0) blocks.push({ type: "text", text });
  };

  const pushImage = (image: ImageContent) => {
    blocks.push({
      type: "image",
      source: {
        type: "base64",
        media_type: image.mimeType as Base64ImageSource["media_type"],
        data: image.data,
      },
    });
  };

  type ContentSummary = { hasText: boolean; hasImage: boolean };

  /** Appends message content and reports whether it held any text / images. */
  const appendContent = (
    content: string | Array<{ type: string; text?: string; data?: string; mimeType?: string }>,
  ): ContentSummary => {
    const seen: ContentSummary = { hasText: false, hasImage: false };
    if (typeof content === "string") {
      pushText(content);
      seen.hasText = content.trim().length > 0;
      return seen;
    }
    if (!Array.isArray(content)) return seen;
    for (const block of content) {
      if (block.type === "text") {
        const text = block.text ?? "";
        if (text.trim().length > 0) seen.hasText = true;
        pushText(text);
      } else if (block.type === "image") {
        seen.hasImage = true;
        pushImage(block as ImageContent);
      } else {
        pushText(`[${block.type}]`);
      }
    }
    return seen;
  };

  /** Adds a short note when a message contributed no text of its own. */
  const noteMissingText = (seen: ContentSummary, emptyLabel: string) => {
    if (seen.hasText) return;
    pushText(seen.hasImage ? "(see attached image)" : emptyLabel);
  };

  const messages = context.messages;

  if (messages.length > 1) {
    // Prefix only when there is actual history (not just the current user turn)
    pushText("[CONVERSATION HISTORY - all tool calls below were executed]\n");
  }

  for (const message of messages) {
    if (message.role === "user") {
      pushText(`\nUSER:\n`);
      noteMissingText(appendContent((message as UserMessage).content), "(empty message)");
      continue;
    }

    if (message.role === "assistant") {
      pushText(`\nASSISTANT:\n`);
      const assistantContent = (
        message as {
          content: Array<{
            type: string;
            text?: string;
            thinking?: string;
            name?: string;
            arguments?: Record<string, unknown>;
          }>;
        }
      ).content;
      if (Array.isArray(assistantContent)) {
        for (const block of assistantContent) {
          if (block.type === "text") pushText(block.text ?? "");
          else if (block.type === "thinking") pushText(block.thinking ?? "");
          else if (block.type === "toolCall") {
            const toolName = mapPiToolNameToSdk(block.name, customToolNameToSdk);
            const args = block.arguments ? JSON.stringify(block.arguments, null, 2) : "{}";
            // BUG FIX: was "Historical tool call (non-executable)" — model interpreted
            // this as the call never having run, causing retry loops.
            pushText(`\n[TOOL CALL - EXECUTED]: ${toolName}\n${args}\n`);
          }
        }
      }
      continue;
    }

    if (message.role === "toolResult") {
      const result = message as ToolResultMessage;
      const toolName = mapPiToolNameToSdk(result.toolName, customToolNameToSdk);
      const status = result.isError ? "FAILED" : "SUCCESS";
      pushText(`\n[TOOL RESULT - ${status}]: ${toolName}\n`);
      noteMissingText(appendContent(result.content), "(no output)");
      continue;
    }
  }

  // Explicit continuation marker so the model knows it's mid-task
  if (messages.length > 1) {
    pushText(
      "\n\n---\n[TASK STATUS: In progress. The history above shows completed work. Continue from where the assistant left off.]\n",
    );
  }

  // Only reachable with an empty history; still return a valid, non-empty text block.
  if (!blocks.length) return [{ type: "text", text: "(empty message)" }];
  return blocks;
}
|
|
|
|
/**
 * Wraps prebuilt prompt blocks in a one-shot async stream. The stream yields
 * exactly one user message, whose session_id is the literal "prompt".
 */
function buildFallbackStream(promptBlocks: ContentBlockParam[]): AsyncIterable<SDKUserMessage> {
  const onlyMessage = {
    type: "user" as const,
    message: { role: "user" as const, content: promptBlocks } as MessageParam,
    parent_tool_use_id: null,
    session_id: "prompt",
  } satisfies SDKUserMessage;

  return (async function* () {
    yield onlyMessage;
  })();
}
|
|
|
|
// =============================================================================
|
|
// Settings
|
|
// =============================================================================
|
|
|
|
// Path of `~/.claude/pi-config.json`.
// NOTE(review): the code reviewed alongside this constant never reads it —
// confirm it is used elsewhere in the file, or remove it.
const CLAUDE_CONFIG_PATH = join(homedir(), ".claude", "pi-config.json");
|
|
|
|
/**
 * Provider options read from the settings block of pi's settings.json files
 * (see readSettingsFile for the accepted block names).
 */
type ProviderSettings = {
  /**
   * Unless explicitly false, AGENTS.md and the skills listing are appended
   * to Claude Code's system prompt.
   */
  appendSystemPrompt?: boolean;
  /** Claude Code setting sources; consulted only when appendSystemPrompt is false. */
  settingSources?: SettingSource[];
  /** Enables `--strict-mcp-config` (only when appendSystemPrompt is false); default on. */
  strictMcpConfig?: boolean;
  /** Extra CLI flags; a null value means a bare flag without an argument. */
  extraArgs?: Record<string, string | null>;
};
|
|
|
|
/**
 * Loads provider settings, with the project settings file overriding the
 * global one key by key.
 *
 * readSettingsFile reports every field it did not find as `undefined`.
 * Those keys are skipped during the merge. Previously a plain spread let a
 * project block that set only, say, `extraArgs` silently erase the global
 * `appendSystemPrompt`, `settingSources` and `strictMcpConfig`.
 *
 * @returns the merged settings (keys without a value are absent)
 */
function loadProviderSettings(): ProviderSettings {
  const merged: ProviderSettings = {};
  const layers = [readSettingsFile(GLOBAL_SETTINGS_PATH), readSettingsFile(PROJECT_SETTINGS_PATH)];
  for (const layer of layers) {
    for (const [key, value] of Object.entries(layer)) {
      if (value !== undefined) (merged as Record<string, unknown>)[key] = value;
    }
  }
  return merged;
}
|
|
|
|
/**
 * Reads this provider's settings from one settings.json file.
 *
 * The block is looked up under `claudeAgentSdkProvider`, then
 * `claude-agent-sdk-provider`, then `claudeAgentSdk`; the first key present
 * wins. A field is kept only if it has the expected type, and is otherwise
 * `undefined`. A missing file, a read/parse failure, or a missing or
 * non-object block all yield `{}`.
 */
function readSettingsFile(filePath: string): ProviderSettings {
  if (!existsSync(filePath)) return {};

  try {
    const root = JSON.parse(readFileSync(filePath, "utf-8")) as Record<string, unknown>;
    const block = (root["claudeAgentSdkProvider"] ??
      root["claude-agent-sdk-provider"] ??
      root["claudeAgentSdk"]) as Record<string, unknown> | undefined;
    if (!block || typeof block !== "object") return {};

    const booleanOrUndefined = (value: unknown): boolean | undefined =>
      typeof value === "boolean" ? value : undefined;

    const knownSources: readonly unknown[] = ["user", "project", "local"];
    const rawSources = block["settingSources"];
    const settingSources =
      Array.isArray(rawSources) && rawSources.every((source) => knownSources.includes(source))
        ? (rawSources as SettingSource[])
        : undefined;

    const rawExtra = block["extraArgs"];
    let extraArgs: Record<string, string | null> | undefined;
    if (rawExtra && typeof rawExtra === "object" && !Array.isArray(rawExtra)) {
      // fromEntries defines own properties, so a "__proto__" key stays a plain key.
      extraArgs = Object.fromEntries(
        Object.entries(rawExtra as Record<string, unknown>).filter(
          (entry): entry is [string, string | null] => typeof entry[1] === "string" || entry[1] === null,
        ),
      );
    }

    return {
      appendSystemPrompt: booleanOrUndefined(block["appendSystemPrompt"]),
      settingSources,
      strictMcpConfig: booleanOrUndefined(block["strictMcpConfig"]),
      extraArgs,
    };
  } catch {
    // Best effort: unreadable or malformed settings are treated as absent.
    return {};
  }
}
|
|
|
|
// =============================================================================
|
|
// System Prompt Helpers
|
|
// =============================================================================
|
|
|
|
/**
 * Cuts pi's skills listing out of the system prompt. The listing runs from
 * the "The following skills provide…" sentence through `</available_skills>`.
 * Its locations are then re-pointed at the Claude-style aliases.
 *
 * Returns undefined when there is no prompt or either marker is missing.
 */
function extractSkillsAppend(systemPrompt?: string): string | undefined {
  const startMarker = "The following skills provide specialized instructions for specific tasks.";
  const endMarker = "</available_skills>";

  if (systemPrompt) {
    const begin = systemPrompt.indexOf(startMarker);
    const end = begin < 0 ? -1 : systemPrompt.indexOf(endMarker, begin);
    if (end >= 0) {
      return rewriteSkillsLocations(systemPrompt.substring(begin, end + endMarker.length).trim());
    }
  }
  return undefined;
}
|
|
|
|
/**
 * Rewrites each `<location>…</location>` in a skills listing. A path inside
 * one of pi's skills roots is shown through the Claude-style alias
 * (`~/.claude/skills/…` or `.claude/skills/…`); other locations are left
 * as they are.
 *
 * Fixes over the previous version:
 *   - A location now counts as inside a root only when it equals the root or
 *     continues with a path separator. A sibling like `…/skills-old/x` used
 *     to be turned into a bogus alias path.
 *   - The relative part is used verbatim. The old `^\.+` strip turned a
 *     hidden directory such as `.hidden/SKILL.md` into `hidden/SKILL.md`.
 */
function rewriteSkillsLocations(skillsBlock: string): string {
  /** Returns the alias form of `location` if it lies under `root`, else undefined. */
  const toAlias = (location: string, root: string, alias: string): string | undefined => {
    const inside =
      location === root || location.startsWith(`${root}/`) || location.startsWith(`${root}\\`);
    if (!inside) return undefined;
    return `${alias}/${relative(root, location)}`.replace(/\/\/+/g, "/");
  };

  return skillsBlock.replace(/<location>([^<]+)<\/location>/g, (_match, location: string) => {
    const rewritten =
      toAlias(location, GLOBAL_SKILLS_ROOT, SKILLS_ALIAS_GLOBAL) ??
      toAlias(location, PROJECT_SKILLS_ROOT, SKILLS_ALIAS_PROJECT) ??
      location;
    return `<location>${rewritten}</location>`;
  });
}
|
|
|
|
/**
 * Locates the AGENTS.md to mirror into Claude Code's system prompt. The
 * nearest one from the current working directory upwards is preferred;
 * otherwise the global GLOBAL_AGENTS_PATH is used if it exists.
 */
function resolveAgentsMdPath(): string | undefined {
  const nearest = findAgentsMdInParents(process.cwd());
  if (nearest) return nearest;
  return existsSync(GLOBAL_AGENTS_PATH) ? GLOBAL_AGENTS_PATH : undefined;
}
|
|
|
|
/**
 * Walks from `startDir` (resolved to an absolute path) up to the filesystem
 * root. Returns the first `AGENTS.md` found, or undefined if none exists
 * along the way.
 */
function findAgentsMdInParents(startDir: string): string | undefined {
  // dirname() of the root returns the root itself, which ends the walk.
  for (let dir = resolve(startDir), previous = ""; dir !== previous; previous = dir, dir = dirname(dir)) {
    const candidate = join(dir, "AGENTS.md");
    if (existsSync(candidate)) return candidate;
  }
  return undefined;
}
|
|
|
|
/**
 * Loads the AGENTS.md chosen by resolveAgentsMdPath, rewrites it with
 * sanitizeAgentsContent, and returns it under a "# CLAUDE.md" heading.
 *
 * Returns undefined when:
 *   - no file is found;
 *   - reading the file fails;
 *   - its trimmed content is empty;
 *   - the sanitised result is empty.
 */
function extractAgentsAppend(): string | undefined {
  const agentsPath = resolveAgentsMdPath();
  if (!agentsPath) return undefined;

  try {
    const text = readFileSync(agentsPath, "utf-8").trim();
    const sanitized = text ? sanitizeAgentsContent(text) : "";
    return sanitized ? `# CLAUDE.md\n\n${sanitized}` : undefined;
  } catch {
    // Best effort: an unreadable AGENTS.md is simply not appended.
    return undefined;
  }
}
|
|
|
|
/**
 * Rewrites pi-specific references in AGENTS.md text for Claude Code. Rules,
 * applied in order:
 *   1. `~/.pi…`                                    → `~/.claude…`
 *   2. `.pi/` at the start or after whitespace/quote/backtick → `.claude/`
 *   3. `.pi` directly after a word character       → `.claude`
 *   4. `.pi` at the start or after any other non-dot character
 *      (e.g. `/`, `(`, whitespace)                 → `.claude`
 *   5. any remaining standalone word "pi" (case-insensitive) → "environment"
 *
 * Rule 4 is new. Without it, a `.pi` segment such as `/home/u/.pi/agent`
 * or `(.pi)` escaped rules 2–3, and rule 5 turned it into `.environment`.
 */
function sanitizeAgentsContent(content: string): string {
  let s = content;
  s = s.replace(/~\/\.pi\b/gi, "~/.claude");
  s = s.replace(/(^|[\s'"`])\.pi\//g, "$1.claude/");
  s = s.replace(/\b\.pi\b/gi, ".claude");
  s = s.replace(/(^|[^\w.])\.pi\b/gi, "$1.claude");
  s = s.replace(/\bpi\b/gi, "environment");
  return s;
}
|
|
|
|
// =============================================================================
|
|
// Thinking Budget
|
|
// =============================================================================
|
|
|
|
/** Reasoning levels accepted through SimpleStreamOptions.reasoning. */
type ThinkingLevel = NonNullable<SimpleStreamOptions["reasoning"]>;

/** All reasoning levels except "xhigh", which non-Opus-4.6 models treat as "high". */
type NonXhighThinkingLevel = Exclude<ThinkingLevel, "xhigh">;

/** maxThinkingTokens per level for every model except Opus 4.6. */
const DEFAULT_THINKING_BUDGETS: Record<NonXhighThinkingLevel, number> = {
  minimal: 2 * 1024,
  low: 8 * 1024,
  medium: 16 * 1024,
  high: 31999,
};

/** maxThinkingTokens per level for Opus 4.6 (which also supports "xhigh"). */
const OPUS_46_THINKING_BUDGETS: Record<ThinkingLevel, number> = {
  minimal: 2 * 1024,
  low: 8 * 1024,
  medium: 31999,
  high: 63999,
  xhigh: 63999,
};
|
|
|
|
/**
 * Chooses the maxThinkingTokens value for a request:
 *   - no reasoning level → undefined (thinking left unconfigured);
 *   - Opus 4.6 (model id containing "opus-4-6" or "opus-4.6") → the
 *     OPUS_46_THINKING_BUDGETS entry; custom budgets are not consulted here;
 *   - any other model → "xhigh" is folded into "high". A positive, finite
 *     custom budget for the level is used if given, otherwise the
 *     DEFAULT_THINKING_BUDGETS entry.
 */
function mapThinkingTokens(
  reasoning?: ThinkingLevel,
  modelId?: string,
  thinkingBudgets?: SimpleStreamOptions["thinkingBudgets"],
): number | undefined {
  if (!reasoning) return undefined;

  if (modelId && /opus-4[-.]6/.test(modelId)) {
    return OPUS_46_THINKING_BUDGETS[reasoning];
  }

  const level: NonXhighThinkingLevel = reasoning === "xhigh" ? "high" : reasoning;
  const custom = (thinkingBudgets as Partial<Record<NonXhighThinkingLevel, number>> | undefined)?.[level];
  const isUsable = typeof custom === "number" && Number.isFinite(custom) && custom > 0;
  return isUsable ? custom : DEFAULT_THINKING_BUDGETS[level];
}
|
|
|
|
// =============================================================================
|
|
// Misc Helpers
|
|
// =============================================================================
|
|
|
|
/**
 * Parses streamed tool-call JSON. Falls back when the text is not (yet) a
 * complete JSON object.
 *
 * Returns `fallback` in three cases:
 *   - empty input;
 *   - malformed or partial JSON;
 *   - JSON that parses to something other than a plain object (`null`,
 *     numbers, strings, arrays).
 *
 * The last case is new. Such values used to be returned typed as
 * Record<string, unknown>, which could break callers that read properties
 * off the tool arguments.
 *
 * @param input - accumulated JSON text
 * @param fallback - value returned whenever `input` is not a JSON object
 */
function parsePartialJson(input: string, fallback: Record<string, unknown>): Record<string, unknown> {
  if (!input) return fallback;
  try {
    const parsed: unknown = JSON.parse(input);
    return parsed !== null && typeof parsed === "object" && !Array.isArray(parsed)
      ? (parsed as Record<string, unknown>)
      : fallback;
  } catch {
    // Incomplete JSON is expected mid-stream.
    return fallback;
  }
}
|
|
|
|
/**
 * Maps an Anthropic `stop_reason` onto pi's stop reasons:
 *   - "tool_use"   → "toolUse"
 *   - "max_tokens" → "length"
 *   - anything else, including undefined → "stop"
 */
function mapStopReason(reason: string | undefined): "stop" | "length" | "toolUse" {
  if (reason === "tool_use") return "toolUse";
  if (reason === "max_tokens") return "length";
  return "stop";
}
|
|
|
|
// =============================================================================
|
|
// Main Stream Function
|
|
// =============================================================================
|
|
|
|
function streamClaudeAgentSdk(
|
|
model: Model<any>,
|
|
context: Context,
|
|
options?: SimpleStreamOptions,
|
|
): AssistantMessageEventStream {
|
|
const stream = createAssistantMessageEventStream();
|
|
|
|
(async () => {
|
|
const output: AssistantMessage = {
|
|
role: "assistant",
|
|
content: [],
|
|
api: model.api,
|
|
provider: model.provider,
|
|
model: model.id,
|
|
usage: {
|
|
input: 0,
|
|
output: 0,
|
|
cacheRead: 0,
|
|
cacheWrite: 0,
|
|
totalTokens: 0,
|
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
|
},
|
|
stopReason: "stop",
|
|
timestamp: Date.now(),
|
|
};
|
|
|
|
let sdkQuery: ReturnType<typeof query> | undefined;
|
|
let wasAborted = false;
|
|
const requestAbort = () => {
|
|
if (!sdkQuery) return;
|
|
void sdkQuery.interrupt().catch(() => {
|
|
try {
|
|
sdkQuery?.close();
|
|
} catch {
|
|
// ignore
|
|
}
|
|
});
|
|
};
|
|
const onAbort = () => {
|
|
wasAborted = true;
|
|
requestAbort();
|
|
};
|
|
if (options?.signal) {
|
|
if (options.signal.aborted) onAbort();
|
|
else options.signal.addEventListener("abort", onAbort, { once: true });
|
|
}
|
|
|
|
const blocks = output.content as Array<
|
|
| { type: "text"; text: string; index: number }
|
|
| { type: "thinking"; thinking: string; thinkingSignature?: string; index: number }
|
|
| {
|
|
type: "toolCall";
|
|
id: string;
|
|
name: string;
|
|
arguments: Record<string, unknown>;
|
|
partialJson: string;
|
|
index: number;
|
|
}
|
|
>;
|
|
|
|
let started = false;
|
|
let sawStreamEvent = false;
|
|
let sawToolCall = false;
|
|
let shouldStopEarly = false;
|
|
|
|
try {
|
|
const { sdkTools, customTools, customToolNameToSdk, customToolNameToPi } =
|
|
resolveSdkTools(context);
|
|
|
|
const cwd = (options as { cwd?: string } | undefined)?.cwd ?? process.cwd();
|
|
const mcpServers = buildCustomToolServers(customTools);
|
|
const providerSettings = loadProviderSettings();
|
|
const appendSystemPrompt = providerSettings.appendSystemPrompt !== false;
|
|
const agentsAppend = appendSystemPrompt ? extractAgentsAppend() : undefined;
|
|
const skillsAppend = appendSystemPrompt
|
|
? extractSkillsAppend(context.systemPrompt)
|
|
: undefined;
|
|
const appendParts = [agentsAppend, skillsAppend].filter((p): p is string => Boolean(p));
|
|
const systemPromptAppend = appendParts.length > 0 ? appendParts.join("\n\n") : undefined;
|
|
const allowSkillAliasRewrite = Boolean(skillsAppend);
|
|
|
|
const settingSources: SettingSource[] | undefined = appendSystemPrompt
|
|
? undefined
|
|
: providerSettings.settingSources ?? ["user", "project"];
|
|
|
|
const strictMcpConfigEnabled = !appendSystemPrompt && providerSettings.strictMcpConfig !== false;
|
|
const generatedExtraArgs = strictMcpConfigEnabled ? { "strict-mcp-config": null } : undefined;
|
|
const extraArgs = { ...generatedExtraArgs, ...providerSettings.extraArgs };
|
|
|
|
// ----------------------------------------------------------------
|
|
// Decide: resume existing session or start fresh?
|
|
// ----------------------------------------------------------------
|
|
const convKey = getConversationKey(context, options);
|
|
const existingSession = sessionRegistry.get(convKey);
|
|
|
|
let prompt: AsyncIterable<SDKUserMessage>;
|
|
let claudeCodeSessionId: string;
|
|
let newSentMsgCount: number;
|
|
let isResume: boolean;
|
|
|
|
if (existingSession && context.messages.length > existingSession.sentMsgCount) {
|
|
// Continuation: resume and inject new tool results / user messages.
|
|
claudeCodeSessionId = existingSession.claudeCodeSessionId;
|
|
newSentMsgCount = context.messages.length;
|
|
isResume = true;
|
|
prompt = buildResumeStream(context, existingSession.sentMsgCount, claudeCodeSessionId);
|
|
} else {
|
|
// Fresh session (first turn, or after session was cleaned up).
|
|
claudeCodeSessionId = randomUUID();
|
|
newSentMsgCount = context.messages.length;
|
|
isResume = false;
|
|
const fallbackBlocks = buildFallbackTextBlocks(context, customToolNameToSdk);
|
|
prompt = buildFallbackStream(fallbackBlocks);
|
|
}
|
|
|
|
const queryOptions: NonNullable<Parameters<typeof query>[0]["options"]> = {
|
|
cwd,
|
|
tools: sdkTools,
|
|
permissionMode: "dontAsk",
|
|
includePartialMessages: true,
|
|
canUseTool: async () => ({
|
|
behavior: "deny",
|
|
message: TOOL_EXECUTION_DENIED_MESSAGE,
|
|
}),
|
|
systemPrompt: {
|
|
type: "preset",
|
|
preset: "claude_code",
|
|
append: systemPromptAppend ?? undefined,
|
|
},
|
|
...(settingSources ? { settingSources } : {}),
|
|
...(Object.keys(extraArgs).length > 0 ? { extraArgs } : {}),
|
|
...(mcpServers ? { mcpServers } : {}),
|
|
// Session management
|
|
...(isResume
|
|
? { resume: claudeCodeSessionId }
|
|
: { sessionId: claudeCodeSessionId, persistSession: true }),
|
|
};
|
|
|
|
const maxThinkingTokens = mapThinkingTokens(options?.reasoning, model.id, options?.thinkingBudgets);
|
|
if (maxThinkingTokens != null) {
|
|
queryOptions.maxThinkingTokens = maxThinkingTokens;
|
|
}
|
|
|
|
// Register / update session state BEFORE the query so we have it
|
|
// even if an error occurs after the first message_stop.
|
|
sessionRegistry.set(convKey, {
|
|
claudeCodeSessionId,
|
|
sentMsgCount: newSentMsgCount,
|
|
modelId: model.id,
|
|
});
|
|
|
|
sdkQuery = query({ prompt, options: queryOptions });
|
|
|
|
if (wasAborted) requestAbort();
|
|
|
|
for await (const message of sdkQuery) {
|
|
if (!started) {
|
|
stream.push({ type: "start", partial: output });
|
|
started = true;
|
|
}
|
|
|
|
switch (message.type) {
|
|
case "stream_event": {
|
|
sawStreamEvent = true;
|
|
const event = (message as SDKMessage & { event: any }).event;
|
|
|
|
if (event?.type === "message_start") {
|
|
const usage = event.message?.usage;
|
|
output.usage.input = usage?.input_tokens ?? 0;
|
|
output.usage.output = usage?.output_tokens ?? 0;
|
|
output.usage.cacheRead = usage?.cache_read_input_tokens ?? 0;
|
|
output.usage.cacheWrite = usage?.cache_creation_input_tokens ?? 0;
|
|
output.usage.totalTokens =
|
|
output.usage.input +
|
|
output.usage.output +
|
|
output.usage.cacheRead +
|
|
output.usage.cacheWrite;
|
|
calculateCost(model, output.usage);
|
|
break;
|
|
}
|
|
|
|
if (event?.type === "content_block_start") {
|
|
if (event.content_block?.type === "text") {
|
|
blocks.push({ type: "text", text: "", index: event.index });
|
|
stream.push({
|
|
type: "text_start",
|
|
contentIndex: output.content.length - 1,
|
|
partial: output,
|
|
});
|
|
} else if (event.content_block?.type === "thinking") {
|
|
blocks.push({
|
|
type: "thinking",
|
|
thinking: "",
|
|
thinkingSignature: "",
|
|
index: event.index,
|
|
});
|
|
stream.push({
|
|
type: "thinking_start",
|
|
contentIndex: output.content.length - 1,
|
|
partial: output,
|
|
});
|
|
} else if (event.content_block?.type === "tool_use") {
|
|
sawToolCall = true;
|
|
blocks.push({
|
|
type: "toolCall",
|
|
id: event.content_block.id,
|
|
name: mapToolName(event.content_block.name, customToolNameToPi),
|
|
arguments: (event.content_block.input as Record<string, unknown>) ?? {},
|
|
partialJson: "",
|
|
index: event.index,
|
|
});
|
|
stream.push({
|
|
type: "toolcall_start",
|
|
contentIndex: output.content.length - 1,
|
|
partial: output,
|
|
});
|
|
}
|
|
break;
|
|
}
|
|
|
|
if (event?.type === "content_block_delta") {
|
|
if (event.delta?.type === "text_delta") {
|
|
const idx = blocks.findIndex((b) => b.index === event.index);
|
|
const block = blocks[idx];
|
|
if (block?.type === "text") {
|
|
block.text += event.delta.text;
|
|
stream.push({
|
|
type: "text_delta",
|
|
contentIndex: idx,
|
|
delta: event.delta.text,
|
|
partial: output,
|
|
});
|
|
}
|
|
} else if (event.delta?.type === "thinking_delta") {
|
|
const idx = blocks.findIndex((b) => b.index === event.index);
|
|
const block = blocks[idx];
|
|
if (block?.type === "thinking") {
|
|
block.thinking += event.delta.thinking;
|
|
stream.push({
|
|
type: "thinking_delta",
|
|
contentIndex: idx,
|
|
delta: event.delta.thinking,
|
|
partial: output,
|
|
});
|
|
}
|
|
} else if (event.delta?.type === "input_json_delta") {
|
|
const idx = blocks.findIndex((b) => b.index === event.index);
|
|
const block = blocks[idx];
|
|
if (block?.type === "toolCall") {
|
|
block.partialJson += event.delta.partial_json;
|
|
block.arguments = parsePartialJson(block.partialJson, block.arguments);
|
|
stream.push({
|
|
type: "toolcall_delta",
|
|
contentIndex: idx,
|
|
delta: event.delta.partial_json,
|
|
partial: output,
|
|
});
|
|
}
|
|
} else if (event.delta?.type === "signature_delta") {
|
|
const idx = blocks.findIndex((b) => b.index === event.index);
|
|
const block = blocks[idx];
|
|
if (block?.type === "thinking") {
|
|
block.thinkingSignature =
|
|
(block.thinkingSignature ?? "") + event.delta.signature;
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
|
|
if (event?.type === "content_block_stop") {
|
|
const idx = blocks.findIndex((b) => b.index === event.index);
|
|
const block = blocks[idx];
|
|
if (!block) break;
|
|
delete (block as any).index;
|
|
if (block.type === "text") {
|
|
stream.push({
|
|
type: "text_end",
|
|
contentIndex: idx,
|
|
content: block.text,
|
|
partial: output,
|
|
});
|
|
} else if (block.type === "thinking") {
|
|
stream.push({
|
|
type: "thinking_end",
|
|
contentIndex: idx,
|
|
content: block.thinking,
|
|
partial: output,
|
|
});
|
|
} else if (block.type === "toolCall") {
|
|
sawToolCall = true;
|
|
block.arguments = mapToolArgs(
|
|
block.name,
|
|
parsePartialJson(block.partialJson, block.arguments),
|
|
allowSkillAliasRewrite,
|
|
);
|
|
delete (block as any).partialJson;
|
|
stream.push({
|
|
type: "toolcall_end",
|
|
contentIndex: idx,
|
|
toolCall: block,
|
|
partial: output,
|
|
});
|
|
}
|
|
break;
|
|
}
|
|
|
|
if (event?.type === "message_delta") {
|
|
output.stopReason = mapStopReason(event.delta?.stop_reason);
|
|
const usage = event.usage ?? {};
|
|
if (usage.input_tokens != null) output.usage.input = usage.input_tokens;
|
|
if (usage.output_tokens != null) output.usage.output = usage.output_tokens;
|
|
if (usage.cache_read_input_tokens != null)
|
|
output.usage.cacheRead = usage.cache_read_input_tokens;
|
|
if (usage.cache_creation_input_tokens != null)
|
|
output.usage.cacheWrite = usage.cache_creation_input_tokens;
|
|
output.usage.totalTokens =
|
|
output.usage.input +
|
|
output.usage.output +
|
|
output.usage.cacheRead +
|
|
output.usage.cacheWrite;
|
|
calculateCost(model, output.usage);
|
|
break;
|
|
}
|
|
|
|
if (event?.type === "message_stop" && sawToolCall) {
|
|
output.stopReason = "toolUse";
|
|
shouldStopEarly = true;
|
|
break;
|
|
}
|
|
|
|
break;
|
|
}
|
|
|
|
case "result": {
|
|
if (!sawStreamEvent && message.subtype === "success") {
|
|
output.content.push({ type: "text", text: (message as any).result || "" });
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (shouldStopEarly) break;
|
|
}
|
|
|
|
if (wasAborted || options?.signal?.aborted) {
|
|
output.stopReason = "aborted";
|
|
output.errorMessage = "Operation aborted";
|
|
stream.push({ type: "error", reason: "aborted", error: output });
|
|
stream.end();
|
|
return;
|
|
}
|
|
|
|
// Clean up session registry when the conversation ends naturally
|
|
// (no tool calls → final answer → conversation over).
|
|
if (output.stopReason === "stop" || output.stopReason === "length") {
|
|
sessionRegistry.delete(convKey);
|
|
}
|
|
|
|
stream.push({
|
|
type: "done",
|
|
reason:
|
|
output.stopReason === "toolUse"
|
|
? "toolUse"
|
|
: output.stopReason === "length"
|
|
? "length"
|
|
: "stop",
|
|
message: output,
|
|
});
|
|
stream.end();
|
|
} catch (error) {
|
|
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
|
output.errorMessage = error instanceof Error ? error.message : String(error);
|
|
stream.push({
|
|
type: "error",
|
|
reason: output.stopReason as "aborted" | "error",
|
|
error: output,
|
|
});
|
|
stream.end();
|
|
} finally {
|
|
if (options?.signal) {
|
|
options.signal.removeEventListener("abort", onAbort);
|
|
}
|
|
sdkQuery?.close();
|
|
}
|
|
})();
|
|
|
|
return stream;
|
|
}
|
|
|
|
// =============================================================================
|
|
// Extension Entry Point
|
|
// =============================================================================
|
|
|
|
export default function (pi: ExtensionAPI) {
	// Extension entry point: registers this module's streaming implementation
	// as a pi provider under PROVIDER_ID, advertising the MODELS list.
	//
	// The config is typed from registerProvider's own second parameter, so the
	// string fields below keep whatever literal types the provider API expects.
	const providerConfig: Parameters<ExtensionAPI["registerProvider"]>[1] = {
		baseUrl: "claude-agent-sdk",
		apiKey: "ANTHROPIC_API_KEY",
		api: "claude-agent-sdk",
		models: MODELS,
		streamSimple: streamClaudeAgentSdk,
	};

	pi.registerProvider(PROVIDER_ID, providerConfig);
}
|