pi config update

This commit is contained in:
Jonas H
2026-03-19 07:58:49 +01:00
parent a3c9183485
commit 871caa5adc
24 changed files with 6198 additions and 555 deletions

View File

@@ -118,11 +118,11 @@ function killOtherSessions(pids: number[]): number {
function statusLabel(account: Account | "unknown"): string {
switch (account) {
case "personal":
return "🏠 personal";
return " personal";
case "work":
return "💼 work";
return "󰃖 work";
default:
return " claude";
return " claude";
}
}
@@ -132,13 +132,29 @@ export default function (pi: ExtensionAPI) {
let currentAccount: Account | "unknown" = "unknown";
pi.on("session_start", async (_event, ctx) => {
// Proper-lockfile creates auth.json.lock as a *directory* (atomic mkdir).
// If a regular file exists at that path (e.g. left by an older pi version),
// rmdir fails with ENOTDIR → lock acquisition throws → loadError is set →
// credentials are never persisted after /login. Delete the stale file and
// reload so this session has working auth persistence.
const lockPath = AUTH_JSON + ".lock";
try {
const stat = fs.statSync(lockPath);
if (stat.isFile()) {
fs.unlinkSync(lockPath);
ctx.modelRegistry.authStorage.reload();
}
} catch {
// lock doesn't exist or we can't stat it — nothing to fix
}
currentAccount = getCurrentAccount();
ctx.ui.setStatus("claude-account", statusLabel(currentAccount));
});
pi.registerCommand("switch-claude", {
description:
"Switch between personal (🏠) and work (💼) Claude accounts. Use 'save <name>' to save current login as a profile.",
"Switch between personal () and work (󰃖) Claude accounts. Use 'save <name>' to save current login as a profile.",
handler: async (args, ctx) => {
const authStorage = ctx.modelRegistry.authStorage;
const trimmed = args?.trim() ?? "";
@@ -178,15 +194,15 @@ export default function (pi: ExtensionAPI) {
if (trimmed === "personal" || trimmed === "work") {
newAccount = trimmed;
} else if (trimmed === "") {
const personalLabel = `🏠 personal${currentAccount === "personal" ? " ← current" : ""}${!hasProfile("personal") ? " (no profile saved)" : ""}`;
const workLabel = `💼 work${currentAccount === "work" ? " ← current" : ""}${!hasProfile("work") ? " (no profile saved)" : ""}`;
const personalLabel = ` personal${currentAccount === "personal" ? " ← current" : ""}${!hasProfile("personal") ? " (no profile saved)" : ""}`;
const workLabel = `󰃖 work${currentAccount === "work" ? " ← current" : ""}${!hasProfile("work") ? " (no profile saved)" : ""}`;
const accountChoice = await ctx.ui.select(
"Switch Claude account:",
[personalLabel, workLabel],
);
if (accountChoice === undefined) return;
newAccount = accountChoice.startsWith("🏠") ? "personal" : "work";
newAccount = accountChoice.startsWith("") ? "personal" : "work";
} else {
ctx.ui.notify(
"Usage: /switch-claude [personal|work|save <name>]",
@@ -276,6 +292,7 @@ export default function (pi: ExtensionAPI) {
currentAccount = newAccount;
setMarker(currentAccount);
ctx.ui.setStatus("claude-account", statusLabel(currentAccount));
pi.events.emit("claude-account:switched", { account: newAccount });
ctx.ui.notify(
`Switched to ${statusLabel(newAccount)}`,
"info",

View File

@@ -0,0 +1,191 @@
/**
* Footer Display Extension
*
* Replaces the built-in pi footer with a single clean line that assembles
* status from all other extensions:
*
* \uEF85 | S ⣿⣶⣀⣀⣀ 34% 2h 55m | W ⣿⣿⣷⣀⣀ 68% ⟳ Fri 09:00 | C ⣿⣀⣀⣀⣀ 20% | Sonnet 4.6 | rust-analyzer | MCP: 1/2
*
* Status sources:
* "claude-account" — set by claude-account-switch.ts → just the icon
* "usage-bars" — set by usage-bars extension → S/W bars (pass-through)
* ctx.getContextUsage() → C bar (rendered here)
* ctx.model → model short name
* "lsp" — set by lsp-pi extension → strip "LSP " prefix
* "mcp" — set by pi-mcp-adapter → strip " servers" suffix
*/
import os from "os";
import path from "path";
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
import { truncateToWidth } from "@mariozechner/pi-tui";
// ---------------------------------------------------------------------------
// Braille gradient bar — used here only for the context (C) bar
// ---------------------------------------------------------------------------
const BRAILLE_GRADIENT = "\u28C0\u28C4\u28E4\u28E6\u28F6\u28F7\u28FF";
const BRAILLE_EMPTY = "\u28C0";
const BAR_WIDTH = 5;
function renderBrailleBar(theme: any, value: number): string {
const v = Math.max(0, Math.min(100, Math.round(value)));
const levels = BRAILLE_GRADIENT.length - 1;
const totalSteps = BAR_WIDTH * levels;
const filledSteps = Math.round((v / 100) * totalSteps);
const full = Math.floor(filledSteps / levels);
const partial = filledSteps % levels;
const empty = BAR_WIDTH - full - (partial ? 1 : 0);
const color = v >= 90 ? "error" : v >= 70 ? "warning" : "success";
const filled = BRAILLE_GRADIENT[BRAILLE_GRADIENT.length - 1]!.repeat(Math.max(0, full));
const partialChar = partial ? BRAILLE_GRADIENT[partial]! : "";
const emptyChars = BRAILLE_EMPTY.repeat(Math.max(0, empty));
return theme.fg(color, filled + partialChar) + theme.fg("dim", emptyChars);
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
function stripAnsi(text: string): string {
return text.replace(/\x1b\[[0-9;]*m/g, "").replace(/\x1b\][^\x07]*\x07/g, "");
}
function getModelShortName(modelId: string): string {
// claude-haiku-4-5 → "Haiku 4.5", claude-sonnet-4-6 → "Sonnet 4.6"
const m = modelId.match(/^claude-([a-z]+)-([\d]+(?:-[\d]+)*)(?:-\d{8})?$/);
if (m) {
const family = m[1]!.charAt(0).toUpperCase() + m[1]!.slice(1);
return `${family} ${m[2]!.replace(/-/g, ".")}`;
}
// claude-3-5-sonnet, claude-3-opus, etc.
const m2 = modelId.match(/^claude-[\d-]+-([a-z]+)/);
if (m2) return m2[1]!.charAt(0).toUpperCase() + m2[1]!.slice(1);
return modelId.replace(/^claude-/, "");
}
// Nerd Font codepoints matched to what claude-account-switch.ts emits
const ICON_PERSONAL = "\uEF85"; // U+EF85 — home
const ICON_WORK = "\uDB80\uDCD6"; // U+F00D6 — briefcase (surrogate pair)
const ICON_UNKNOWN = "\uF420"; // U+F420 — claude default
export default function (pi: ExtensionAPI) {
let ctx: any = null;
let tuiRef: any = null;
let footerDataRef: any = null;
// ---------------------------------------------------------------------------
// Footer line builder — called on every render
// ---------------------------------------------------------------------------
function buildFooterLine(theme: any): string {
const sep = theme.fg("dim", " · ");
const pipeSep = theme.fg("dim", " | ");
const parts: string[] = [];
const statuses: ReadonlyMap<string, string> =
footerDataRef?.getExtensionStatuses?.() ?? new Map();
// 1. Current working directory
const home = os.homedir();
const cwd = process.cwd();
const dir = cwd.startsWith(home)
? "~" + path.sep + path.relative(home, cwd)
: cwd;
parts.push(theme.fg("dim", dir));
// 2. Account icon
const acctRaw = statuses.get("claude-account");
if (acctRaw !== undefined) {
const clean = stripAnsi(acctRaw).trim();
let icon: string;
if (clean.includes("personal")) icon = ICON_PERSONAL;
else if (clean.includes("work")) icon = ICON_WORK;
else icon = ICON_UNKNOWN;
parts.push(theme.fg("dim", icon));
}
// 3. S / W usage bars + C bar — joined as one |-separated block
const usageRaw = statuses.get("usage-bars");
const contextUsage = ctx?.getContextUsage?.();
{
let block = usageRaw ?? "";
if (contextUsage && contextUsage.percent !== null) {
const pct = Math.round(contextUsage.percent);
const cBar =
theme.fg("muted", "C ") +
renderBrailleBar(theme, pct) +
" " +
theme.fg("dim", `${pct}%`);
block = block ? block + pipeSep + cBar : cBar;
}
if (block) parts.push(block);
}
// 4. Model short name
const modelId = ctx?.model?.id;
if (modelId) parts.push(theme.fg("dim", getModelShortName(modelId)));
// 5. LSP — strip "LSP" prefix and activity dot
const lspRaw = statuses.get("lsp");
if (lspRaw) {
const clean = stripAnsi(lspRaw).trim().replace(/^LSP\s*[•·]?\s*/i, "").trim();
if (clean) parts.push(theme.fg("dim", clean));
}
// 6. MCP — strip " servers" suffix
const mcpRaw = statuses.get("mcp");
if (mcpRaw) {
const clean = stripAnsi(mcpRaw).trim().replace(/\s*servers?.*$/, "").trim();
if (clean) parts.push(theme.fg("dim", clean));
}
return parts.join(sep);
}
// ---------------------------------------------------------------------------
// Footer installation
// ---------------------------------------------------------------------------
function installFooter(_ctx: any) {
if (!_ctx?.hasUI) return;
_ctx.ui.setFooter((_tui: any, theme: any, footerData: any) => {
tuiRef = _tui;
footerDataRef = footerData;
const unsub = footerData.onBranchChange(() => _tui.requestRender());
return {
dispose: unsub,
invalidate() {},
render(width: number): string[] {
return [truncateToWidth(buildFooterLine(theme) || "", width)];
},
};
});
}
// ---------------------------------------------------------------------------
// Event handlers
// ---------------------------------------------------------------------------
pi.on("session_start", async (_event, _ctx) => {
ctx = _ctx;
installFooter(_ctx);
});
pi.on("session_shutdown", (_event, _ctx) => {
if (_ctx?.hasUI) _ctx.ui.setFooter(undefined);
});
// Re-render after turns so context usage stays current
pi.on("turn_end", (_event, _ctx) => {
ctx = _ctx;
if (tuiRef) tuiRef.requestRender();
});
// Re-render when model changes (updates model name in footer)
pi.on("model_select", (_event, _ctx) => {
ctx = _ctx;
if (tuiRef) tuiRef.requestRender();
});
// Re-render when account switches (usage:update comes from usage-bars setStatus which
// already triggers a render, but account icon needs a nudge too)
pi.events.on("claude-account:switched", () => {
if (tuiRef) tuiRef.requestRender();
});
}

View File

@@ -0,0 +1,99 @@
/**
* Git Checkout Guard Extension
*
* Prevents models from using `git checkout` or `git restore` to silently
* discard uncommitted changes in files.
*/
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
import { isToolCallEventType } from "@mariozechner/pi-coding-agent";
import { execSync } from "child_process";
/**
* Parse file paths from a git checkout/restore command.
* Returns null if the command doesn't look like a file-restore operation.
*/
function parseFileRestoreArgs(command: string): string[] | null {
// Normalize whitespace
const cmd = command.trim().replace(/\s+/g, " ");
// Match: git checkout -- <files>
// Match: git checkout <ref> -- <files>
const checkoutDashDash = cmd.match(/\bgit\s+checkout\b.*?\s--\s+(.+)/);
if (checkoutDashDash) {
return checkoutDashDash[1].trim().split(/\s+/);
}
// Match: git restore [--staged] [--source=<ref>] <files>
// (git restore always operates on files)
const restore = cmd.match(/\bgit\s+restore\s+(.+)/);
if (restore) {
// Filter out flags like --staged, --source=..., --worktree, --patch
const args = restore[1].trim().split(/\s+/);
const files = args.filter((a) => !a.startsWith("-"));
return files.length > 0 ? files : null;
}
return null;
}
/**
* Check which of the given file paths have uncommitted changes (staged or unstaged).
* Returns the subset that are dirty.
*/
function getDirtyFiles(files: string[], cwd: string): string[] {
const dirty: string[] = [];
for (const file of files) {
try {
// --porcelain output is empty for clean files
const out = execSync(`git status --porcelain -- ${JSON.stringify(file)}`, {
cwd,
encoding: "utf8",
stdio: ["ignore", "pipe", "ignore"],
}).trim();
if (out.length > 0) {
dirty.push(file);
}
} catch {
// Not a git repo or other error — skip
}
}
return dirty;
}
export default function (pi: ExtensionAPI) {
pi.on("tool_call", async (event, ctx) => {
if (!isToolCallEventType("bash", event)) return undefined;
const command: string = event.input.command ?? "";
const files = parseFileRestoreArgs(command);
if (!files || files.length === 0) return undefined;
const cwd = process.cwd();
const dirty = getDirtyFiles(files, cwd);
if (dirty.length === 0) return undefined; // nothing to protect
const fileList = dirty.map((f) => `${f}`).join("\n");
if (!ctx.hasUI) {
return {
block: true,
reason: `git-checkout-guard: the following files have uncommitted changes and cannot be silently reverted:\n${fileList}\nShow the diff to the user and ask for explicit confirmation first.`,
};
}
const choice = await ctx.ui.select(
`⚠️ git-checkout-guard\n\nThe command:\n ${command}\n\nwould discard uncommitted changes in:\n${fileList}\n\nProceed?`,
["No, cancel (show diff instead)", "Yes, discard changes anyway"],
);
if (choice !== "Yes, discard changes anyway") {
return {
block: true,
reason: `Blocked by git-checkout-guard. Run \`git diff ${dirty.join(" ")}\` and review before discarding.`,
};
}
return undefined;
});
}

View File

@@ -1,293 +0,0 @@
/**
* llama-server Schema Sanitization Proxy
*
* llama-server strictly validates JSON Schema and rejects any schema node
* that lacks a `type` field. Some of pi's built-in tools (e.g. `subagent`)
* have complex union-type parameters represented as `{"description": "..."}` with
* no `type`, which causes llama-server to return a 400 error.
*
* This extension provides an optional tiny local HTTP proxy on port 8081 that:
* 1. Intercepts outgoing OpenAI-compatible API calls
* 2. Walks tool schemas and adds `"type": "string"` to any schema node
* that is missing a type declaration
* 3. Forwards the fixed request to llama-server on port 8080
* 4. Streams the response back transparently
*
* It also overrides the `llama-cpp` provider's baseUrl to point at the proxy,
* so no changes to models.json are needed (beyond what's already there).
*
* Use `/llama-proxy` command to toggle the proxy on/off. Off by default.
*/
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
import * as http from "http";
import { execSync } from "child_process";
const PROXY_PORT = 8081;
const TARGET_HOST = "127.0.0.1";
const TARGET_PORT = 8080;
// ---------------------------------------------------------------------------
// Schema sanitizer
// ---------------------------------------------------------------------------
/**
* Recursively walk a JSON Schema object and add `"type": "string"` to any
* node that has no `type` and no composition keywords (oneOf/anyOf/allOf/$ref).
* This satisfies llama-server's strict validation without breaking valid nodes.
*/
function sanitizeSchema(schema: unknown): unknown {
if (!schema || typeof schema !== "object") return schema;
if (Array.isArray(schema)) return schema.map(sanitizeSchema);
const obj = schema as Record<string, unknown>;
const result: Record<string, unknown> = {};
for (const [key, value] of Object.entries(obj)) {
if (key === "properties" && value && typeof value === "object" && !Array.isArray(value)) {
result[key] = Object.fromEntries(
Object.entries(value as Record<string, unknown>).map(([k, v]) => [k, sanitizeSchema(v)]),
);
} else if (key === "items") {
result[key] = sanitizeSchema(value);
} else if (key === "additionalProperties" && value && typeof value === "object") {
result[key] = sanitizeSchema(value);
} else if (
(key === "oneOf" || key === "anyOf" || key === "allOf") &&
Array.isArray(value)
) {
result[key] = value.map(sanitizeSchema);
} else {
result[key] = value;
}
}
// If this schema node has no type and no composition keywords, default to "string"
const hasType = "type" in result;
const hasComposition =
"oneOf" in result || "anyOf" in result || "allOf" in result || "$ref" in result;
const hasEnum = "enum" in result || "const" in result;
if (!hasType && !hasComposition && !hasEnum) {
result["type"] = "string";
}
return result;
}
/**
* Patch the `tools` array in a parsed request body, if present.
*/
function sanitizeRequestBody(body: Record<string, unknown>): Record<string, unknown> {
if (!Array.isArray(body.tools)) return body;
return {
...body,
tools: (body.tools as unknown[]).map((tool) => {
if (!tool || typeof tool !== "object") return tool;
const t = tool as Record<string, unknown>;
if (!t.function || typeof t.function !== "object") return t;
const fn = t.function as Record<string, unknown>;
if (!fn.parameters) return t;
return {
...t,
function: {
...fn,
parameters: sanitizeSchema(fn.parameters),
},
};
}),
};
}
// ---------------------------------------------------------------------------
// Process management
// ---------------------------------------------------------------------------
/**
* Kill any existing processes using the proxy port.
*/
function killExistingProxy(): void {
try {
// Use lsof to find processes on the port and kill them
const output = execSync(`lsof -ti:${PROXY_PORT} 2>/dev/null || true`, {
encoding: "utf-8",
});
const pids = output.trim().split("\n").filter(Boolean);
for (const pid of pids) {
try {
process.kill(Number(pid), "SIGTERM");
console.log(`[llama-proxy] Terminated old instance (PID: ${pid})`);
} catch {
// Process may have already exited
}
}
} catch {
// lsof not available or other error — continue anyway
}
}
// ---------------------------------------------------------------------------
// Proxy server
// ---------------------------------------------------------------------------
function startProxy(): http.Server {
const server = http.createServer((req, res) => {
const chunks: Buffer[] = [];
req.on("data", (chunk: Buffer) => chunks.push(chunk));
req.on("end", () => {
const rawBody = Buffer.concat(chunks).toString("utf-8");
// Attempt to sanitize schemas in JSON bodies
let forwardBody = rawBody;
const contentType = req.headers["content-type"] ?? "";
if (contentType.includes("application/json") && rawBody.trim().startsWith("{")) {
try {
const parsed = JSON.parse(rawBody) as Record<string, unknown>;
const sanitized = sanitizeRequestBody(parsed);
forwardBody = JSON.stringify(sanitized);
} catch {
// Not valid JSON — send as-is
}
}
const forwardBuffer = Buffer.from(forwardBody, "utf-8");
// Build forwarded headers, updating host and content-length
const forwardHeaders: Record<string, string | string[]> = {};
for (const [k, v] of Object.entries(req.headers)) {
if (k === "host") continue; // rewrite below
if (v !== undefined) forwardHeaders[k] = v as string | string[];
}
forwardHeaders["host"] = `${TARGET_HOST}:${TARGET_PORT}`;
forwardHeaders["content-length"] = String(forwardBuffer.byteLength);
const proxyReq = http.request(
{
host: TARGET_HOST,
port: TARGET_PORT,
path: req.url,
method: req.method,
headers: forwardHeaders,
},
(proxyRes) => {
res.writeHead(proxyRes.statusCode ?? 200, proxyRes.headers);
proxyRes.pipe(res, { end: true });
},
);
proxyReq.on("error", (err) => {
const msg = `Proxy error forwarding to llama-server: ${err.message}`;
if (!res.headersSent) {
res.writeHead(502, { "content-type": "text/plain" });
}
res.end(msg);
});
proxyReq.write(forwardBuffer);
proxyReq.end();
});
req.on("error", (err) => {
console.error("[llama-proxy] request error:", err);
});
});
server.listen(PROXY_PORT, "127.0.0.1", () => {
console.log(`[llama-proxy] Proxy started on port ${PROXY_PORT}`);
});
server.on("error", (err: NodeJS.ErrnoException) => {
if (err.code === "EADDRINUSE") {
console.error(
`[llama-proxy] Port ${PROXY_PORT} already in use. ` +
`Killing old instances and retrying...`,
);
killExistingProxy();
} else {
console.error("[llama-proxy] Server error:", err);
}
});
return server;
}
// ---------------------------------------------------------------------------
// Extension entry point
// ---------------------------------------------------------------------------
export default function (pi: ExtensionAPI) {
let server: http.Server | null = null;
let proxyEnabled = false;
/**
* Start the proxy and register the provider override.
*/
function enableProxy(): void {
if (proxyEnabled) {
console.log("[llama-proxy] Proxy already enabled");
return;
}
killExistingProxy();
server = startProxy();
// Override the llama-cpp provider's baseUrl to route through our proxy.
// models.json model definitions are preserved; only the endpoint changes.
pi.registerProvider("llama-cpp", {
baseUrl: `http://127.0.0.1:${PROXY_PORT}/v1`,
});
proxyEnabled = true;
console.log("[llama-proxy] Proxy enabled");
}
/**
* Disable the proxy and restore default provider.
*/
function disableProxy(): void {
if (!proxyEnabled) {
console.log("[llama-proxy] Proxy already disabled");
return;
}
if (server) {
server.close();
server = null;
}
// Reset provider to default (no baseUrl override)
pi.registerProvider("llama-cpp", {});
proxyEnabled = false;
console.log("[llama-proxy] Proxy disabled");
}
// Register the /llama-proxy command to toggle the proxy
pi.registerCommand("llama-proxy", async (args) => {
const action = args[0]?.toLowerCase() || "";
if (action === "on") {
enableProxy();
} else if (action === "off") {
disableProxy();
} else if (action === "status") {
console.log(`[llama-proxy] Status: ${proxyEnabled ? "enabled" : "disabled"}`);
} else {
// Toggle if no argument
if (proxyEnabled) {
disableProxy();
} else {
enableProxy();
}
}
});
// Clean up on session end
pi.on("session_end", async () => {
if (server) {
server.close();
}
});
}

View File

@@ -0,0 +1,178 @@
# LSP Extension
Language Server Protocol integration for pi-coding-agent.
## Highlights
- **Hook** (`lsp.ts`): Auto-diagnostics (default at agent end; optional per `write`/`edit`)
- **Tool** (`lsp-tool.ts`): On-demand LSP queries (definitions, references, hover, symbols, diagnostics, signatures)
- Manages one LSP server per project root and reuses them across turns
- **Efficient**: Bounded memory usage via LRU cache and idle file cleanup
- Supports TypeScript/JavaScript, Vue, Svelte, Dart/Flutter, Python, Go, Kotlin, Swift, and Rust
## Supported Languages
| Language | Server | Detection |
|----------|--------|-----------|
| TypeScript/JavaScript | `typescript-language-server` | `package.json`, `tsconfig.json` |
| Vue | `vue-language-server` | `package.json`, `vite.config.ts` |
| Svelte | `svelteserver` | `svelte.config.js` |
| Dart/Flutter | `dart language-server` | `pubspec.yaml` |
| Python | `pyright-langserver` | `pyproject.toml`, `requirements.txt` |
| Go | `gopls` | `go.mod` |
| Kotlin | `kotlin-ls` | `settings.gradle(.kts)`, `build.gradle(.kts)`, `pom.xml` |
| Swift | `sourcekit-lsp` | `Package.swift`, Xcode (`*.xcodeproj` / `*.xcworkspace`) |
| Rust | `rust-analyzer` | `Cargo.toml` |
### Known Limitations
**rust-analyzer**: Very slow to initialize (30-60+ seconds) because it compiles the entire Rust project before returning diagnostics. This is a known rust-analyzer behavior, not a bug in this extension. For quick feedback, consider using `cargo check` directly.
## Usage
### Installation
Install the package and enable extensions:
```bash
pi install npm:lsp-pi
pi config
```
Dependencies are installed automatically during `pi install`.
### Prerequisites
Install the language servers you need:
```bash
# TypeScript/JavaScript
npm i -g typescript-language-server typescript
# Vue
npm i -g @vue/language-server
# Svelte
npm i -g svelte-language-server
# Python
npm i -g pyright
# Go (install gopls via go install)
go install golang.org/x/tools/gopls@latest
# Kotlin (kotlin-ls)
brew install JetBrains/utils/kotlin-lsp
# Swift (sourcekit-lsp; macOS)
# Usually available via Xcode / Command Line Tools
xcrun sourcekit-lsp --help
# Rust (install via rustup)
rustup component add rust-analyzer
```
The extension spawns binaries from your PATH.
## How It Works
### Hook (auto-diagnostics)
1. On `session_start`, warms up LSP for detected project type
2. Tracks files touched by `write`/`edit`
3. Default (`agent_end`): at agent end, sends touched files to LSP and posts a diagnostics message
4. Optional (`edit_write`): per `write`/`edit`, appends diagnostics to the tool result
5. Shows notification with diagnostic summary
6. **Memory Management**: Keeps up to 30 files open per LSP server (LRU eviction), automatically closes idle files (> 60s), and shuts down all LSP servers after 2 minutes of post-agent inactivity (servers restart lazily when files are read again).
7. **Robustness**: Reuses cached diagnostics if a server doesn't re-publish them for unchanged files, avoiding false timeouts on re-analysis.
### Tool (on-demand queries)
The `lsp` tool provides these actions:
| Action | Description | Requires |
|--------|-------------|----------|
| `definition` | Jump to definition | `file` + (`line`/`column` or `query`) |
| `references` | Find all references | `file` + (`line`/`column` or `query`) |
| `hover` | Get type/docs info | `file` + (`line`/`column` or `query`) |
| `symbols` | List symbols in file | `file`, optional `query` filter |
| `diagnostics` | Get single file diagnostics | `file`, optional `severity` filter |
| `workspace-diagnostics` | Get diagnostics for multiple files | `files` array, optional `severity` filter |
| `signature` | Get function signature | `file` + (`line`/`column` or `query`) |
| `rename` | Rename symbol across files | `file` + (`line`/`column` or `query`) + `newName` |
| `codeAction` | Get available quick fixes/refactors | `file` + `line`/`column`, optional `endLine`/`endColumn` |
**Query resolution**: For position-based actions, you can provide a `query` (symbol name) instead of `line`/`column`. The tool will find the symbol in the file and use its position.
**Severity filtering**: For `diagnostics` and `workspace-diagnostics` actions, use the `severity` parameter to filter results:
- `all` (default): Show all diagnostics
- `error`: Only errors
- `warning`: Errors and warnings
- `info`: Errors, warnings, and info
- `hint`: All including hints
**Workspace diagnostics**: The `workspace-diagnostics` action analyzes multiple files at once. Pass an array of file paths in the `files` parameter. Each file will be opened, analyzed by the appropriate LSP server, and diagnostics returned. Files are cleaned up after analysis to prevent memory bloat.
```bash
# Find all TypeScript files and check for errors
find src -name "*.ts" -type f | xargs ...
# Example tool call
lsp action=workspace-diagnostics files=["src/index.ts", "src/utils.ts"] severity=error
```
Example questions the LLM can answer using this tool:
- "Where is `handleSessionStart` defined in `lsp-hook.ts`?"
- "Find all references to `getManager`"
- "What type does `getDefinition` return?"
- "List symbols in `lsp-core.ts`"
- "Check all TypeScript files in src/ for errors"
- "Get only errors from `index.ts`"
- "Rename `oldFunction` to `newFunction`"
- "What quick fixes are available at line 10?"
## Settings
Use `/lsp` to configure the auto diagnostics hook:
- Mode: default at agent end; can run after each edit/write or be disabled
- Scope: session-only or global (`~/.pi/agent/settings.json`)
To disable auto diagnostics, choose "Disabled" in `/lsp` or set in `~/.pi/agent/settings.json`:
```json
{
"lsp": {
"hookMode": "disabled"
}
}
```
Other values: `"agent_end"` (default) and `"edit_write"`.
Agent-end mode analyzes files touched during the full agent response (after all tool calls complete) and posts a diagnostics message only once. Disabling the hook does not disable the `/lsp` tool.
## File Structure
| File | Purpose |
|------|---------|
| `lsp.ts` | Hook extension (auto-diagnostics; default at agent end) |
| `lsp-tool.ts` | Tool extension (on-demand LSP queries) |
| `lsp-core.ts` | LSPManager class, server configs, singleton manager |
| `package.json` | Declares both extensions via "pi" field |
## Testing
```bash
# Unit tests (root detection, configuration)
npm test
# Tool tests
npm run test:tool
# Integration tests (spawns real language servers)
npm run test:integration
# Run rust-analyzer tests (slow, disabled by default)
RUST_LSP_TEST=1 npm run test:integration
```
## License
MIT

View File

@@ -0,0 +1,12 @@
/**
* Combined lsp-pi extension entry point.
* Loads both the hook extension (lsp.ts) and the tool extension (lsp-tool.ts).
*/
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
import lspHook from "./lsp.js";
import lspTool from "./lsp-tool.js";
export default function (pi: ExtensionAPI) {
lspHook(pi);
lspTool(pi);
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,382 @@
/**
* LSP Tool Extension for pi-coding-agent
*
* Provides Language Server Protocol tool for:
* - definitions, references, hover, signature help
* - document symbols, diagnostics, workspace diagnostics
* - rename, code actions
*
* Supported languages:
* - Dart/Flutter (dart language-server)
* - TypeScript/JavaScript (typescript-language-server)
* - Vue (vue-language-server)
* - Svelte (svelteserver)
* - Python (pyright-langserver)
* - Go (gopls)
* - Kotlin (kotlin-ls)
* - Swift (sourcekit-lsp)
* - Rust (rust-analyzer)
*
* Usage:
* pi --extension ./lsp-tool.ts
*
* Or use the combined lsp.ts extension for both hook and tool functionality.
*/
import * as path from "node:path";
import { Type, type Static } from "@sinclair/typebox";
import { StringEnum } from "@mariozechner/pi-ai";
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
import { Text } from "@mariozechner/pi-tui";
import { getOrCreateManager, formatDiagnostic, filterDiagnosticsBySeverity, uriToPath, resolvePosition, type SeverityFilter } from "./lsp-core.js";
const PREVIEW_LINES = 10;
const DIAGNOSTICS_WAIT_MS_DEFAULT = 3000;
function diagnosticsWaitMsForFile(filePath: string): number {
const ext = path.extname(filePath).toLowerCase();
if (ext === ".kt" || ext === ".kts") return 30000;
if (ext === ".swift") return 20000;
if (ext === ".rs") return 20000;
return DIAGNOSTICS_WAIT_MS_DEFAULT;
}
const ACTIONS = ["definition", "references", "hover", "symbols", "diagnostics", "workspace-diagnostics", "signature", "rename", "codeAction"] as const;
const SEVERITY_FILTERS = ["all", "error", "warning", "info", "hint"] as const;
const LspParams = Type.Object({
action: StringEnum(ACTIONS),
file: Type.Optional(Type.String({ description: "File path (required for most actions)" })),
files: Type.Optional(Type.Array(Type.String(), { description: "File paths for workspace-diagnostics" })),
line: Type.Optional(Type.Number({ description: "Line (1-indexed). Required for position-based actions unless query provided." })),
column: Type.Optional(Type.Number({ description: "Column (1-indexed). Required for position-based actions unless query provided." })),
endLine: Type.Optional(Type.Number({ description: "End line for range-based actions (codeAction)" })),
endColumn: Type.Optional(Type.Number({ description: "End column for range-based actions (codeAction)" })),
query: Type.Optional(Type.String({ description: "Symbol name filter (for symbols) or to resolve position (for definition/references/hover/signature)" })),
newName: Type.Optional(Type.String({ description: "New name for rename action" })),
severity: Type.Optional(StringEnum(SEVERITY_FILTERS, { description: 'Filter diagnostics: "all"|"error"|"warning"|"info"|"hint"' })),
});
type LspParamsType = Static<typeof LspParams>;
function abortable<T>(promise: Promise<T>, signal?: AbortSignal): Promise<T> {
if (!signal) return promise;
if (signal.aborted) return Promise.reject(new Error("aborted"));
return new Promise<T>((resolve, reject) => {
const onAbort = () => {
cleanup();
reject(new Error("aborted"));
};
const cleanup = () => {
signal.removeEventListener("abort", onAbort);
};
signal.addEventListener("abort", onAbort, { once: true });
promise.then(
(value) => {
cleanup();
resolve(value);
},
(err) => {
cleanup();
reject(err);
},
);
});
}
function isAbortedError(e: unknown): boolean {
return e instanceof Error && e.message === "aborted";
}
function cancelledToolResult() {
return {
content: [{ type: "text" as const, text: "Cancelled" }],
details: { cancelled: true },
};
}
type ExecuteArgs = {
signal: AbortSignal | undefined;
onUpdate: ((update: { content: Array<{ type: "text"; text: string }>; details?: Record<string, unknown> }) => void) | undefined;
ctx: { cwd: string };
};
function isAbortSignalLike(value: unknown): value is AbortSignal {
return !!value
&& typeof value === "object"
&& "aborted" in value
&& typeof (value as any).aborted === "boolean"
&& typeof (value as any).addEventListener === "function";
}
function isContextLike(value: unknown): value is { cwd: string } {
return !!value && typeof value === "object" && typeof (value as any).cwd === "string";
}
function normalizeExecuteArgs(onUpdateArg: unknown, ctxArg: unknown, signalArg: unknown): ExecuteArgs {
// Runtime >= 0.51: (signal, onUpdate, ctx)
if (isContextLike(signalArg)) {
return {
signal: isAbortSignalLike(onUpdateArg) ? onUpdateArg : undefined,
onUpdate: typeof ctxArg === "function" ? ctxArg as ExecuteArgs["onUpdate"] : undefined,
ctx: signalArg,
};
}
// Runtime <= 0.50: (onUpdate, ctx, signal)
if (isContextLike(ctxArg)) {
return {
signal: isAbortSignalLike(signalArg) ? signalArg : undefined,
onUpdate: typeof onUpdateArg === "function" ? onUpdateArg as ExecuteArgs["onUpdate"] : undefined,
ctx: ctxArg,
};
}
throw new Error("Invalid tool execution context");
}
function formatLocation(loc: { uri: string; range?: { start?: { line: number; character: number } } }, cwd?: string): string {
const abs = uriToPath(loc.uri);
const display = cwd && path.isAbsolute(abs) ? path.relative(cwd, abs) : abs;
const { line, character: col } = loc.range?.start ?? {};
return typeof line === "number" && typeof col === "number" ? `${display}:${line + 1}:${col + 1}` : display;
}
function formatHover(contents: unknown): string {
if (typeof contents === "string") return contents;
if (Array.isArray(contents)) return contents.map(c => typeof c === "string" ? c : (c as any)?.value ?? "").filter(Boolean).join("\n\n");
if (contents && typeof contents === "object" && "value" in contents) return String((contents as any).value);
return "";
}
function formatSignature(help: any): string {
if (!help?.signatures?.length) return "No signature help available.";
const sig = help.signatures[help.activeSignature ?? 0] ?? help.signatures[0];
let text = sig.label ?? "Signature";
if (sig.documentation) text += `\n${typeof sig.documentation === "string" ? sig.documentation : sig.documentation?.value ?? ""}`;
if (sig.parameters?.length) {
const params = sig.parameters.map((p: any) => typeof p.label === "string" ? p.label : Array.isArray(p.label) ? p.label.join("-") : "").filter(Boolean);
if (params.length) text += `\nParameters: ${params.join(", ")}`;
}
return text;
}
function collectSymbols(symbols: any[], depth = 0, lines: string[] = [], query?: string): string[] {
for (const sym of symbols) {
const name = sym?.name ?? "<unknown>";
if (query && !name.toLowerCase().includes(query.toLowerCase())) {
if (sym.children?.length) collectSymbols(sym.children, depth + 1, lines, query);
continue;
}
const loc = sym?.range?.start ? `${sym.range.start.line + 1}:${sym.range.start.character + 1}` : "";
lines.push(`${" ".repeat(depth)}${name}${loc ? ` (${loc})` : ""}`);
if (sym.children?.length) collectSymbols(sym.children, depth + 1, lines, query);
}
return lines;
}
function formatWorkspaceEdit(edit: any, cwd?: string): string {
const lines: string[] = [];
if (edit.documentChanges?.length) {
for (const change of edit.documentChanges) {
if (change.textDocument?.uri) {
const fp = uriToPath(change.textDocument.uri);
const display = cwd && path.isAbsolute(fp) ? path.relative(cwd, fp) : fp;
lines.push(`${display}:`);
for (const e of change.edits || []) {
const loc = `${e.range.start.line + 1}:${e.range.start.character + 1}`;
lines.push(` [${loc}] → "${e.newText}"`);
}
}
}
}
if (edit.changes) {
for (const [uri, edits] of Object.entries(edit.changes)) {
const fp = uriToPath(uri);
const display = cwd && path.isAbsolute(fp) ? path.relative(cwd, fp) : fp;
lines.push(`${display}:`);
for (const e of edits as any[]) {
const loc = `${e.range.start.line + 1}:${e.range.start.character + 1}`;
lines.push(` [${loc}] → "${e.newText}"`);
}
}
}
return lines.length ? lines.join("\n") : "No edits.";
}
function formatCodeActions(actions: any[]): string[] {
return actions.map((a, i) => {
const title = a.title || a.command?.title || "Untitled action";
const kind = a.kind ? ` (${a.kind})` : "";
const isPreferred = a.isPreferred ? " ★" : "";
return `${i + 1}. ${title}${kind}${isPreferred}`;
});
}
export default function (pi: ExtensionAPI) {
pi.registerTool({
name: "lsp",
label: "LSP",
description: `Query language server for definitions, references, types, symbols, diagnostics, rename, and code actions.
Actions: definition, references, hover, signature, rename (require file + line/column or query), symbols (file, optional query), diagnostics (file), workspace-diagnostics (files array), codeAction (file + position).
Use bash to find files: find src -name "*.ts" -type f`,
parameters: LspParams,
async execute(_toolCallId, params, onUpdateArg, ctxArg, signalArg) {
const { signal, onUpdate, ctx } = normalizeExecuteArgs(onUpdateArg, ctxArg, signalArg);
if (signal?.aborted) return cancelledToolResult();
if (onUpdate) {
onUpdate({ content: [{ type: "text", text: "Working..." }], details: { status: "working" } });
}
const manager = getOrCreateManager(ctx.cwd);
const { action, file, files, line, column, endLine, endColumn, query, newName, severity } = params as LspParamsType;
const sevFilter: SeverityFilter = severity || "all";
const needsFile = action !== "workspace-diagnostics";
const needsPos = ["definition", "references", "hover", "signature", "rename", "codeAction"].includes(action);
try {
if (needsFile && !file) throw new Error(`Action "${action}" requires a file path.`);
let rLine = line, rCol = column, fromQuery = false;
if (needsPos && (rLine === undefined || rCol === undefined) && query && file) {
const resolved = await abortable(resolvePosition(manager, file, query), signal);
if (resolved) { rLine = resolved.line; rCol = resolved.column; fromQuery = true; }
}
if (needsPos && (rLine === undefined || rCol === undefined)) {
throw new Error(`Action "${action}" requires line/column or a query matching a symbol.`);
}
const qLine = query ? `query: ${query}\n` : "";
const sevLine = sevFilter !== "all" ? `severity: ${sevFilter}\n` : "";
const posLine = fromQuery && rLine && rCol ? `resolvedPosition: ${rLine}:${rCol}\n` : "";
switch (action) {
case "definition": {
const results = await abortable(manager.getDefinition(file!, rLine!, rCol!), signal);
const locs = results.map(l => formatLocation(l, ctx?.cwd));
const payload = locs.length ? locs.join("\n") : fromQuery ? `${file}:${rLine}:${rCol}` : "No definitions found.";
return { content: [{ type: "text", text: `action: definition\n${qLine}${posLine}${payload}` }], details: results };
}
case "references": {
const results = await abortable(manager.getReferences(file!, rLine!, rCol!), signal);
const locs = results.map(l => formatLocation(l, ctx?.cwd));
return { content: [{ type: "text", text: `action: references\n${qLine}${posLine}${locs.length ? locs.join("\n") : "No references found."}` }], details: results };
}
case "hover": {
const result = await abortable(manager.getHover(file!, rLine!, rCol!), signal);
const payload = result ? formatHover(result.contents) || "No hover information." : "No hover information.";
return { content: [{ type: "text", text: `action: hover\n${qLine}${posLine}${payload}` }], details: result ?? null };
}
case "symbols": {
const symbols = await abortable(manager.getDocumentSymbols(file!), signal);
const lines = collectSymbols(symbols, 0, [], query);
const payload = lines.length ? lines.join("\n") : query ? `No symbols matching "${query}".` : "No symbols found.";
return { content: [{ type: "text", text: `action: symbols\n${qLine}${payload}` }], details: symbols };
}
case "diagnostics": {
const result = await abortable(manager.touchFileAndWait(file!, diagnosticsWaitMsForFile(file!)), signal);
const filtered = filterDiagnosticsBySeverity(result.diagnostics, sevFilter);
const payload = (result as any).unsupported
? `Unsupported: ${(result as any).error || "No LSP for this file."}`
: !result.receivedResponse
? "Timeout: LSP server did not respond. Try again."
: filtered.length ? filtered.map(formatDiagnostic).join("\n") : "No diagnostics.";
return { content: [{ type: "text", text: `action: diagnostics\n${sevLine}${payload}` }], details: { ...result, diagnostics: filtered } };
}
case "workspace-diagnostics": {
if (!files?.length) throw new Error('Action "workspace-diagnostics" requires a "files" array.');
const waitMs = Math.max(...files.map(diagnosticsWaitMsForFile));
const result = await abortable(manager.getDiagnosticsForFiles(files, waitMs), signal);
const out: string[] = [];
let errors = 0, warnings = 0, filesWithIssues = 0;
for (const item of result.items) {
const display = ctx?.cwd && path.isAbsolute(item.file) ? path.relative(ctx.cwd, item.file) : item.file;
if (item.status !== 'ok') { out.push(`${display}: ${item.error || item.status}`); continue; }
const filtered = filterDiagnosticsBySeverity(item.diagnostics, sevFilter);
if (filtered.length) {
filesWithIssues++;
out.push(`${display}:`);
for (const d of filtered) {
if (d.severity === 1) errors++; else if (d.severity === 2) warnings++;
out.push(` ${formatDiagnostic(d)}`);
}
}
}
const summary = `Analyzed ${result.items.length} file(s): ${errors} error(s), ${warnings} warning(s) in ${filesWithIssues} file(s)`;
return { content: [{ type: "text", text: `action: workspace-diagnostics\n${sevLine}${summary}\n\n${out.length ? out.join("\n") : "No diagnostics."}` }], details: result };
}
case "signature": {
const result = await abortable(manager.getSignatureHelp(file!, rLine!, rCol!), signal);
return { content: [{ type: "text", text: `action: signature\n${qLine}${posLine}${formatSignature(result)}` }], details: result ?? null };
}
case "rename": {
if (!newName) throw new Error('Action "rename" requires a "newName" parameter.');
const result = await abortable(manager.rename(file!, rLine!, rCol!, newName), signal);
if (!result) return { content: [{ type: "text", text: `action: rename\n${qLine}${posLine}No rename available at this position.` }], details: null };
const edits = formatWorkspaceEdit(result, ctx?.cwd);
return { content: [{ type: "text", text: `action: rename\n${qLine}${posLine}newName: ${newName}\n\n${edits}` }], details: result };
}
case "codeAction": {
const result = await abortable(manager.getCodeActions(file!, rLine!, rCol!, endLine, endColumn), signal);
const actions = formatCodeActions(result);
return { content: [{ type: "text", text: `action: codeAction\n${qLine}${posLine}${actions.length ? actions.join("\n") : "No code actions available."}` }], details: result };
}
}
} catch (e) {
if (signal?.aborted || isAbortedError(e)) return cancelledToolResult();
throw e;
}
},
renderCall(args, theme) {
const params = args as LspParamsType;
let text = theme.fg("toolTitle", theme.bold("lsp ")) + theme.fg("accent", params.action || "...");
if (params.file) text += " " + theme.fg("muted", params.file);
else if (params.files?.length) text += " " + theme.fg("muted", `${params.files.length} file(s)`);
if (params.query) text += " " + theme.fg("dim", `query="${params.query}"`);
else if (params.line !== undefined && params.column !== undefined) text += theme.fg("warning", `:${params.line}:${params.column}`);
if (params.severity && params.severity !== "all") text += " " + theme.fg("dim", `[${params.severity}]`);
return new Text(text, 0, 0);
},
renderResult(result, options, theme) {
if (options.isPartial) return new Text(theme.fg("warning", "Working..."), 0, 0);
const textContent = (result.content?.find((c: any) => c.type === "text") as any)?.text || "";
const lines = textContent.split("\n");
let headerEnd = 0;
for (let i = 0; i < lines.length; i++) {
if (/^(action|query|severity|resolvedPosition):/.test(lines[i])) headerEnd = i + 1;
else break;
}
const header = lines.slice(0, headerEnd);
const content = lines.slice(headerEnd);
const maxLines = options.expanded ? content.length : PREVIEW_LINES;
const display = content.slice(0, maxLines);
const remaining = content.length - maxLines;
let out = header.map((l: string) => theme.fg("muted", l)).join("\n");
if (display.length) {
if (out) out += "\n";
out += display.map((l: string) => theme.fg("toolOutput", l)).join("\n");
}
if (remaining > 0) out += theme.fg("dim", `\n... (${remaining} more lines)`);
return new Text(out, 0, 0);
},
});
}

View File

@@ -0,0 +1,604 @@
/**
* LSP Hook Extension for pi-coding-agent
*
* Provides automatic diagnostics feedback (default: agent end).
* Can run after each write/edit or once per agent response.
*
* Usage:
* pi --extension ./lsp.ts
*
* Or load the directory to get both hook and tool:
* pi --extension ./lsp/
*/
import * as path from "node:path";
import * as fs from "node:fs";
import * as os from "node:os";
import { type ExtensionAPI, type ExtensionContext } from "@mariozechner/pi-coding-agent";
import { Text } from "@mariozechner/pi-tui";
import { type Diagnostic } from "vscode-languageserver-protocol";
import { LSP_SERVERS, formatDiagnostic, getOrCreateManager, shutdownManager } from "./lsp-core.js";
type HookScope = "session" | "global";
type HookMode = "edit_write" | "agent_end" | "disabled";
const DIAGNOSTICS_WAIT_MS_DEFAULT = 3000;
function diagnosticsWaitMsForFile(filePath: string): number {
const ext = path.extname(filePath).toLowerCase();
if (ext === ".kt" || ext === ".kts") return 30000;
if (ext === ".swift") return 20000;
if (ext === ".rs") return 20000;
return DIAGNOSTICS_WAIT_MS_DEFAULT;
}
const DIAGNOSTICS_PREVIEW_LINES = 10;
const LSP_IDLE_SHUTDOWN_MS = 2 * 60 * 1000;
const DIM = "\x1b[2m", GREEN = "\x1b[32m", YELLOW = "\x1b[33m", RESET = "\x1b[0m";
const DEFAULT_HOOK_MODE: HookMode = "agent_end";
const SETTINGS_NAMESPACE = "lsp";
const LSP_CONFIG_ENTRY = "lsp-hook-config";
const WARMUP_MAP: Record<string, string> = {
"pubspec.yaml": ".dart",
"package.json": ".ts",
"pyproject.toml": ".py",
"go.mod": ".go",
"Cargo.toml": ".rs",
"settings.gradle": ".kt",
"settings.gradle.kts": ".kt",
"build.gradle": ".kt",
"build.gradle.kts": ".kt",
"pom.xml": ".kt",
"gradlew": ".kt",
"gradle.properties": ".kt",
"Package.swift": ".swift",
};
const MODE_LABELS: Record<HookMode, string> = {
edit_write: "After each edit/write",
agent_end: "At agent end",
disabled: "Disabled",
};
function normalizeHookMode(value: unknown): HookMode | undefined {
if (value === "edit_write" || value === "agent_end" || value === "disabled") return value;
if (value === "turn_end") return "agent_end";
return undefined;
}
interface HookConfigEntry {
scope: HookScope;
hookMode?: HookMode;
}
export default function (pi: ExtensionAPI) {
type LspActivity = "idle" | "loading" | "working";
let activeClients: Set<string> = new Set();
let statusUpdateFn: ((key: string, text: string | undefined) => void) | null = null;
let hookMode: HookMode = DEFAULT_HOOK_MODE;
let hookScope: HookScope = "global";
let activity: LspActivity = "idle";
let diagnosticsAbort: AbortController | null = null;
let shuttingDown = false;
let idleShutdownTimer: NodeJS.Timeout | null = null;
const touchedFiles: Map<string, boolean> = new Map();
const globalSettingsPath = path.join(os.homedir(), ".pi", "agent", "settings.json");
function readSettingsFile(filePath: string): Record<string, unknown> {
try {
if (!fs.existsSync(filePath)) return {};
const raw = fs.readFileSync(filePath, "utf-8");
const parsed = JSON.parse(raw);
return parsed && typeof parsed === "object" ? parsed as Record<string, unknown> : {};
} catch {
return {};
}
}
function getGlobalHookMode(): HookMode | undefined {
const settings = readSettingsFile(globalSettingsPath);
const lspSettings = settings[SETTINGS_NAMESPACE];
const hookValue = (lspSettings as { hookMode?: unknown; hookEnabled?: unknown } | undefined)?.hookMode;
const normalized = normalizeHookMode(hookValue);
if (normalized) return normalized;
const legacyEnabled = (lspSettings as { hookEnabled?: unknown } | undefined)?.hookEnabled;
if (typeof legacyEnabled === "boolean") return legacyEnabled ? "edit_write" : "disabled";
return undefined;
}
function setGlobalHookMode(mode: HookMode): boolean {
try {
const settings = readSettingsFile(globalSettingsPath);
const existing = settings[SETTINGS_NAMESPACE];
const nextNamespace = (existing && typeof existing === "object")
? { ...(existing as Record<string, unknown>), hookMode: mode }
: { hookMode: mode };
settings[SETTINGS_NAMESPACE] = nextNamespace;
fs.mkdirSync(path.dirname(globalSettingsPath), { recursive: true });
fs.writeFileSync(globalSettingsPath, JSON.stringify(settings, null, 2), "utf-8");
return true;
} catch {
return false;
}
}
function getLastHookEntry(ctx: ExtensionContext): HookConfigEntry | undefined {
const branchEntries = ctx.sessionManager.getBranch();
let latest: HookConfigEntry | undefined;
for (const entry of branchEntries) {
if (entry.type === "custom" && entry.customType === LSP_CONFIG_ENTRY) {
latest = entry.data as HookConfigEntry | undefined;
}
}
return latest;
}
function restoreHookState(ctx: ExtensionContext): void {
const entry = getLastHookEntry(ctx);
if (entry?.scope === "session") {
const normalized = normalizeHookMode(entry.hookMode);
if (normalized) {
hookMode = normalized;
hookScope = "session";
return;
}
const legacyEnabled = (entry as { hookEnabled?: unknown }).hookEnabled;
if (typeof legacyEnabled === "boolean") {
hookMode = legacyEnabled ? "edit_write" : "disabled";
hookScope = "session";
return;
}
}
const globalSetting = getGlobalHookMode();
hookMode = globalSetting ?? DEFAULT_HOOK_MODE;
hookScope = "global";
}
function persistHookEntry(entry: HookConfigEntry): void {
pi.appendEntry<HookConfigEntry>(LSP_CONFIG_ENTRY, entry);
}
function labelForMode(mode: HookMode): string {
return MODE_LABELS[mode];
}
function messageContentToText(content: unknown): string {
if (typeof content === "string") return content;
if (Array.isArray(content)) {
return content
.map((item) => (item && typeof item === "object" && "type" in item && (item as any).type === "text")
? String((item as any).text ?? "")
: "")
.filter(Boolean)
.join("\n");
}
return "";
}
function formatDiagnosticsForDisplay(text: string): string {
return text
.replace(/\n?This file has errors, please fix\n/gi, "\n")
.replace(/<\/?file_diagnostics>\n?/gi, "")
.replace(/\n{3,}/g, "\n\n")
.trim();
}
function setActivity(next: LspActivity): void {
activity = next;
updateLspStatus();
}
function clearIdleShutdownTimer(): void {
if (!idleShutdownTimer) return;
clearTimeout(idleShutdownTimer);
idleShutdownTimer = null;
}
async function shutdownLspServersForIdle(): Promise<void> {
diagnosticsAbort?.abort();
diagnosticsAbort = null;
setActivity("idle");
await shutdownManager();
activeClients.clear();
updateLspStatus();
}
function scheduleIdleShutdown(): void {
clearIdleShutdownTimer();
idleShutdownTimer = setTimeout(() => {
idleShutdownTimer = null;
if (shuttingDown) return;
void shutdownLspServersForIdle();
}, LSP_IDLE_SHUTDOWN_MS);
(idleShutdownTimer as any).unref?.();
}
function updateLspStatus(): void {
if (!statusUpdateFn) return;
const clients = activeClients.size > 0 ? [...activeClients].join(", ") : "";
const clientsText = clients ? `${DIM}${clients}${RESET}` : "";
const activityHint = activity === "idle" ? "" : `${DIM}${RESET}`;
if (hookMode === "disabled") {
const text = clientsText
? `${YELLOW}LSP${RESET} ${DIM}(tool)${RESET}: ${clientsText}`
: `${YELLOW}LSP${RESET} ${DIM}(tool)${RESET}`;
statusUpdateFn("lsp", text);
return;
}
let text = `${GREEN}LSP${RESET}`;
if (activityHint) text += ` ${activityHint}`;
if (clientsText) text += ` ${clientsText}`;
statusUpdateFn("lsp", text);
}
function normalizeFilePath(filePath: string, cwd: string): string {
return path.isAbsolute(filePath) ? filePath : path.resolve(cwd, filePath);
}
pi.registerMessageRenderer("lsp-diagnostics", (message, options, theme) => {
const content = formatDiagnosticsForDisplay(messageContentToText(message.content));
if (!content) return new Text("", 0, 0);
const expanded = options.expanded === true;
const lines = content.split("\n");
const maxLines = expanded ? lines.length : DIAGNOSTICS_PREVIEW_LINES;
const display = lines.slice(0, maxLines);
const remaining = lines.length - display.length;
const styledLines = display.map((line) => {
if (line.startsWith("File: ")) return theme.fg("muted", line);
return theme.fg("toolOutput", line);
});
if (!expanded && remaining > 0) {
styledLines.push(theme.fg("dim", `... (${remaining} more lines)`));
}
return new Text(styledLines.join("\n"), 0, 0);
});
function getServerConfig(filePath: string) {
const ext = path.extname(filePath);
return LSP_SERVERS.find((s) => s.extensions.includes(ext));
}
function ensureActiveClientForFile(filePath: string, cwd: string): string | undefined {
const absPath = normalizeFilePath(filePath, cwd);
const cfg = getServerConfig(absPath);
if (!cfg) return undefined;
if (!activeClients.has(cfg.id)) {
activeClients.add(cfg.id);
updateLspStatus();
}
return absPath;
}
function extractLspFiles(input: Record<string, unknown>): string[] {
const files: string[] = [];
if (typeof input.file === "string") files.push(input.file);
if (Array.isArray(input.files)) {
for (const item of input.files) {
if (typeof item === "string") files.push(item);
}
}
return files;
}
function buildDiagnosticsOutput(
filePath: string,
diagnostics: Diagnostic[],
cwd: string,
includeFileHeader: boolean,
): { notification: string; errorCount: number; output: string } {
const absPath = path.isAbsolute(filePath) ? filePath : path.resolve(cwd, filePath);
const relativePath = path.relative(cwd, absPath);
const errorCount = diagnostics.filter((e) => e.severity === 1).length;
const MAX = 5;
const lines = diagnostics.slice(0, MAX).map((e) => {
const sev = e.severity === 1 ? "ERROR" : "WARN";
return `${sev}[${e.range.start.line + 1}] ${e.message.split("\n")[0]}`;
});
let notification = `📋 ${relativePath}\n${lines.join("\n")}`;
if (diagnostics.length > MAX) notification += `\n... +${diagnostics.length - MAX} more`;
const header = includeFileHeader ? `File: ${relativePath}\n` : "";
const output = `\n${header}This file has errors, please fix\n<file_diagnostics>\n${diagnostics.map(formatDiagnostic).join("\n")}\n</file_diagnostics>\n`;
return { notification, errorCount, output };
}
async function collectDiagnostics(
filePath: string,
ctx: ExtensionContext,
includeWarnings: boolean,
includeFileHeader: boolean,
notify = true,
): Promise<string | undefined> {
const manager = getOrCreateManager(ctx.cwd);
const absPath = ensureActiveClientForFile(filePath, ctx.cwd);
if (!absPath) return undefined;
try {
const result = await manager.touchFileAndWait(absPath, diagnosticsWaitMsForFile(absPath));
if (!result.receivedResponse) return undefined;
const diagnostics = includeWarnings
? result.diagnostics
: result.diagnostics.filter((d) => d.severity === 1);
if (!diagnostics.length) return undefined;
const report = buildDiagnosticsOutput(filePath, diagnostics, ctx.cwd, includeFileHeader);
if (notify) {
if (ctx.hasUI) ctx.ui.notify(report.notification, report.errorCount > 0 ? "error" : "warning");
else console.error(report.notification);
}
return report.output;
} catch {
return undefined;
}
}
pi.registerCommand("lsp", {
description: "LSP settings (auto diagnostics hook)",
handler: async (_args, ctx) => {
if (!ctx.hasUI) {
ctx.ui.notify("LSP settings require UI", "warning");
return;
}
const currentMark = " ✓";
const modeOptions = ([
"edit_write",
"agent_end",
"disabled",
] as HookMode[]).map((mode) => ({
mode,
label: mode === hookMode ? `${labelForMode(mode)}${currentMark}` : labelForMode(mode),
}));
const modeChoice = await ctx.ui.select(
"LSP auto diagnostics hook mode:",
modeOptions.map((option) => option.label),
);
if (!modeChoice) return;
const nextMode = modeOptions.find((option) => option.label === modeChoice)?.mode;
if (!nextMode) return;
const scopeOptions = [
{
scope: "session" as HookScope,
label: "Session only",
},
{
scope: "global" as HookScope,
label: "Global (all sessions)",
},
];
const scopeChoice = await ctx.ui.select(
"Apply LSP auto diagnostics hook setting to:",
scopeOptions.map((option) => option.label),
);
if (!scopeChoice) return;
const scope = scopeOptions.find((option) => option.label === scopeChoice)?.scope;
if (!scope) return;
if (scope === "global") {
const ok = setGlobalHookMode(nextMode);
if (!ok) {
ctx.ui.notify("Failed to update global settings", "error");
return;
}
}
hookMode = nextMode;
hookScope = scope;
touchedFiles.clear();
persistHookEntry({ scope, hookMode: nextMode });
updateLspStatus();
ctx.ui.notify(`LSP hook: ${labelForMode(hookMode)} (${hookScope})`, "info");
},
});
pi.on("session_start", async (_event, ctx) => {
restoreHookState(ctx);
statusUpdateFn = ctx.hasUI && ctx.ui.setStatus ? ctx.ui.setStatus.bind(ctx.ui) : null;
updateLspStatus();
if (hookMode === "disabled") return;
const manager = getOrCreateManager(ctx.cwd);
for (const [marker, ext] of Object.entries(WARMUP_MAP)) {
if (fs.existsSync(path.join(ctx.cwd, marker))) {
setActivity("loading");
manager.getClientsForFile(path.join(ctx.cwd, `dummy${ext}`))
.then((clients) => {
if (clients.length > 0) {
const cfg = LSP_SERVERS.find((s) => s.extensions.includes(ext));
if (cfg) activeClients.add(cfg.id);
}
})
.catch(() => {})
.finally(() => setActivity("idle"));
break;
}
}
});
pi.on("session_switch", async (_event, ctx) => {
restoreHookState(ctx);
updateLspStatus();
});
pi.on("session_tree", async (_event, ctx) => {
restoreHookState(ctx);
updateLspStatus();
});
pi.on("session_fork", async (_event, ctx) => {
restoreHookState(ctx);
updateLspStatus();
});
pi.on("session_shutdown", async () => {
shuttingDown = true;
clearIdleShutdownTimer();
diagnosticsAbort?.abort();
diagnosticsAbort = null;
setActivity("idle");
await shutdownManager();
activeClients.clear();
statusUpdateFn?.("lsp", undefined);
});
pi.on("tool_call", async (event, ctx) => {
const input = (event.input && typeof event.input === "object")
? event.input as Record<string, unknown>
: {};
if (event.toolName === "lsp") {
clearIdleShutdownTimer();
const files = extractLspFiles(input);
for (const file of files) {
ensureActiveClientForFile(file, ctx.cwd);
}
return;
}
if (event.toolName !== "read" && event.toolName !== "write" && event.toolName !== "edit") return;
clearIdleShutdownTimer();
const filePath = typeof input.path === "string" ? input.path : undefined;
if (!filePath) return;
const absPath = ensureActiveClientForFile(filePath, ctx.cwd);
if (!absPath) return;
void getOrCreateManager(ctx.cwd).getClientsForFile(absPath).catch(() => {});
});
pi.on("agent_start", async () => {
clearIdleShutdownTimer();
diagnosticsAbort?.abort();
diagnosticsAbort = null;
setActivity("idle");
touchedFiles.clear();
});
function agentWasAborted(event: any): boolean {
const messages = Array.isArray(event?.messages) ? event.messages : [];
return messages.some((m: any) =>
m &&
typeof m === "object" &&
(m as any).role === "assistant" &&
(((m as any).stopReason === "aborted") || ((m as any).stopReason === "error"))
);
}
pi.on("agent_end", async (event, ctx) => {
try {
if (hookMode !== "agent_end") return;
if (agentWasAborted(event)) {
// Don't run diagnostics on aborted/error runs.
touchedFiles.clear();
return;
}
if (touchedFiles.size === 0) return;
if (!ctx.isIdle() || ctx.hasPendingMessages()) return;
const abort = new AbortController();
diagnosticsAbort?.abort();
diagnosticsAbort = abort;
setActivity("working");
const files = Array.from(touchedFiles.entries());
touchedFiles.clear();
try {
const outputs: string[] = [];
for (const [filePath, includeWarnings] of files) {
if (shuttingDown || abort.signal.aborted) return;
if (!ctx.isIdle() || ctx.hasPendingMessages()) {
abort.abort();
return;
}
const output = await collectDiagnostics(filePath, ctx, includeWarnings, true, false);
if (abort.signal.aborted) return;
if (output) outputs.push(output);
}
if (shuttingDown || abort.signal.aborted) return;
if (outputs.length) {
pi.sendMessage({
customType: "lsp-diagnostics",
content: outputs.join("\n"),
display: true,
}, {
triggerTurn: true,
deliverAs: "followUp",
});
}
} finally {
if (diagnosticsAbort === abort) diagnosticsAbort = null;
if (!shuttingDown) setActivity("idle");
}
} finally {
if (!shuttingDown) scheduleIdleShutdown();
}
});
pi.on("tool_result", async (event, ctx) => {
if (event.toolName !== "write" && event.toolName !== "edit") return;
const filePath = event.input.path as string;
if (!filePath) return;
const absPath = ensureActiveClientForFile(filePath, ctx.cwd);
if (!absPath) return;
if (hookMode === "disabled") return;
if (hookMode === "agent_end") {
const includeWarnings = event.toolName === "write";
const existing = touchedFiles.get(absPath) ?? false;
touchedFiles.set(absPath, existing || includeWarnings);
return;
}
const includeWarnings = event.toolName === "write";
const output = await collectDiagnostics(absPath, ctx, includeWarnings, false);
if (!output) return;
return { content: [...event.content, { type: "text" as const, text: output }] as Array<{ type: "text"; text: string }> };
});
}

View File

@@ -0,0 +1,54 @@
{
"name": "lsp-pi",
"version": "1.0.3",
"description": "LSP extension for pi-coding-agent - provides language server tool and diagnostics feedback for Dart/Flutter, TypeScript, Vue, Svelte, Python, Go, Kotlin, Swift, Rust",
"scripts": {
"test": "npx tsx tests/lsp.test.ts",
"test:tool": "npx tsx tests/index.test.ts",
"test:integration": "npx tsx tests/lsp-integration.test.ts",
"test:all": "npm test && npm run test:tool && npm run test:integration"
},
"keywords": [
"lsp",
"language-server",
"dart",
"flutter",
"typescript",
"vue",
"svelte",
"python",
"go",
"kotlin",
"swift",
"rust",
"pi-coding-agent",
"extension",
"pi-package"
],
"author": "",
"license": "MIT",
"type": "module",
"pi": {
"extensions": [
"./lsp.ts",
"./lsp-tool.ts"
]
},
"dependencies": {
"@sinclair/typebox": "^0.34.33",
"vscode-languageserver-protocol": "^3.17.5"
},
"peerDependencies": {
"@mariozechner/pi-ai": "^0.50.0",
"@mariozechner/pi-coding-agent": "^0.50.0",
"@mariozechner/pi-tui": "^0.50.0"
},
"devDependencies": {
"@mariozechner/pi-ai": "^0.50.0",
"@mariozechner/pi-coding-agent": "^0.50.0",
"@mariozechner/pi-tui": "^0.50.0",
"@types/node": "^24.10.2",
"tsx": "^4.21.0",
"typescript": "^5.9.3"
}
}

View File

@@ -0,0 +1,235 @@
/**
* Unit tests for index.ts formatting functions
*/
// ============================================================================
// Test utilities
// ============================================================================
const tests: Array<{ name: string; fn: () => void | Promise<void> }> = [];
function test(name: string, fn: () => void | Promise<void>) {
tests.push({ name, fn });
}
function assertEqual<T>(actual: T, expected: T, message?: string) {
const a = JSON.stringify(actual);
const e = JSON.stringify(expected);
if (a !== e) throw new Error(message || `Expected ${e}, got ${a}`);
}
// ============================================================================
// Import the module to test internal functions
// We need to test via the execute function since formatters are private
// Or we can extract and test the logic directly
// ============================================================================
import { uriToPath, findSymbolPosition, formatDiagnostic, filterDiagnosticsBySeverity } from "../lsp-core.js";
// ============================================================================
// uriToPath tests
// ============================================================================
test("uriToPath: converts file:// URI to path", () => {
const result = uriToPath("file:///Users/test/file.ts");
assertEqual(result, "/Users/test/file.ts");
});
test("uriToPath: handles encoded characters", () => {
const result = uriToPath("file:///Users/test/my%20file.ts");
assertEqual(result, "/Users/test/my file.ts");
});
test("uriToPath: passes through non-file URIs", () => {
const result = uriToPath("/some/path.ts");
assertEqual(result, "/some/path.ts");
});
test("uriToPath: handles invalid URIs gracefully", () => {
const result = uriToPath("not-a-valid-uri");
assertEqual(result, "not-a-valid-uri");
});
// ============================================================================
// findSymbolPosition tests
// ============================================================================
test("findSymbolPosition: finds exact match", () => {
const symbols = [
{ name: "greet", range: { start: { line: 5, character: 10 }, end: { line: 5, character: 15 } }, selectionRange: { start: { line: 5, character: 10 }, end: { line: 5, character: 15 } }, kind: 12, children: [] },
{ name: "hello", range: { start: { line: 10, character: 0 }, end: { line: 10, character: 5 } }, selectionRange: { start: { line: 10, character: 0 }, end: { line: 10, character: 5 } }, kind: 12, children: [] },
];
const pos = findSymbolPosition(symbols as any, "greet");
assertEqual(pos, { line: 5, character: 10 });
});
test("findSymbolPosition: finds partial match", () => {
const symbols = [
{ name: "getUserName", range: { start: { line: 3, character: 0 }, end: { line: 3, character: 11 } }, selectionRange: { start: { line: 3, character: 0 }, end: { line: 3, character: 11 } }, kind: 12, children: [] },
];
const pos = findSymbolPosition(symbols as any, "user");
assertEqual(pos, { line: 3, character: 0 });
});
test("findSymbolPosition: prefers exact over partial", () => {
const symbols = [
{ name: "userName", range: { start: { line: 1, character: 0 }, end: { line: 1, character: 8 } }, selectionRange: { start: { line: 1, character: 0 }, end: { line: 1, character: 8 } }, kind: 12, children: [] },
{ name: "user", range: { start: { line: 5, character: 0 }, end: { line: 5, character: 4 } }, selectionRange: { start: { line: 5, character: 0 }, end: { line: 5, character: 4 } }, kind: 12, children: [] },
];
const pos = findSymbolPosition(symbols as any, "user");
assertEqual(pos, { line: 5, character: 0 });
});
test("findSymbolPosition: searches nested children", () => {
const symbols = [
{
name: "MyClass",
range: { start: { line: 0, character: 0 }, end: { line: 10, character: 0 } },
selectionRange: { start: { line: 0, character: 0 }, end: { line: 0, character: 7 } },
kind: 5,
children: [
{ name: "myMethod", range: { start: { line: 2, character: 2 }, end: { line: 4, character: 2 } }, selectionRange: { start: { line: 2, character: 2 }, end: { line: 2, character: 10 } }, kind: 6, children: [] },
]
},
];
const pos = findSymbolPosition(symbols as any, "myMethod");
assertEqual(pos, { line: 2, character: 2 });
});
test("findSymbolPosition: returns null for no match", () => {
const symbols = [
{ name: "foo", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 3 } }, selectionRange: { start: { line: 0, character: 0 }, end: { line: 0, character: 3 } }, kind: 12, children: [] },
];
const pos = findSymbolPosition(symbols as any, "bar");
assertEqual(pos, null);
});
test("findSymbolPosition: case insensitive", () => {
const symbols = [
{ name: "MyFunction", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 10 } }, selectionRange: { start: { line: 0, character: 0 }, end: { line: 0, character: 10 } }, kind: 12, children: [] },
];
const pos = findSymbolPosition(symbols as any, "myfunction");
assertEqual(pos, { line: 0, character: 0 });
});
// ============================================================================
// formatDiagnostic tests
// ============================================================================
test("formatDiagnostic: formats error", () => {
const diag = {
range: { start: { line: 5, character: 10 }, end: { line: 5, character: 15 } },
message: "Type 'number' is not assignable to type 'string'",
severity: 1,
};
const result = formatDiagnostic(diag as any);
assertEqual(result, "ERROR [6:11] Type 'number' is not assignable to type 'string'");
});
test("formatDiagnostic: formats warning", () => {
const diag = {
range: { start: { line: 0, character: 0 }, end: { line: 0, character: 5 } },
message: "Unused variable",
severity: 2,
};
const result = formatDiagnostic(diag as any);
assertEqual(result, "WARN [1:1] Unused variable");
});
test("formatDiagnostic: formats info", () => {
const diag = {
range: { start: { line: 2, character: 4 }, end: { line: 2, character: 10 } },
message: "Consider using const",
severity: 3,
};
const result = formatDiagnostic(diag as any);
assertEqual(result, "INFO [3:5] Consider using const");
});
test("formatDiagnostic: formats hint", () => {
const diag = {
range: { start: { line: 0, character: 0 }, end: { line: 0, character: 1 } },
message: "Prefer arrow function",
severity: 4,
};
const result = formatDiagnostic(diag as any);
assertEqual(result, "HINT [1:1] Prefer arrow function");
});
// ============================================================================
// filterDiagnosticsBySeverity tests
// ============================================================================
test("filterDiagnosticsBySeverity: all returns everything", () => {
const diags = [
{ severity: 1, message: "error", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
{ severity: 2, message: "warning", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
{ severity: 3, message: "info", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
{ severity: 4, message: "hint", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
];
const result = filterDiagnosticsBySeverity(diags as any, "all");
assertEqual(result.length, 4);
});
test("filterDiagnosticsBySeverity: error returns only errors", () => {
const diags = [
{ severity: 1, message: "error", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
{ severity: 2, message: "warning", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
];
const result = filterDiagnosticsBySeverity(diags as any, "error");
assertEqual(result.length, 1);
assertEqual(result[0].message, "error");
});
test("filterDiagnosticsBySeverity: warning returns errors and warnings", () => {
const diags = [
{ severity: 1, message: "error", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
{ severity: 2, message: "warning", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
{ severity: 3, message: "info", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
];
const result = filterDiagnosticsBySeverity(diags as any, "warning");
assertEqual(result.length, 2);
});
test("filterDiagnosticsBySeverity: info returns errors, warnings, and info", () => {
const diags = [
{ severity: 1, message: "error", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
{ severity: 2, message: "warning", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
{ severity: 3, message: "info", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
{ severity: 4, message: "hint", range: { start: { line: 0, character: 0 }, end: { line: 0, character: 0 } } },
];
const result = filterDiagnosticsBySeverity(diags as any, "info");
assertEqual(result.length, 3);
});
// ============================================================================
// Run tests
// ============================================================================
async function runTests(): Promise<void> {
console.log("Running index.ts unit tests...\n");
let passed = 0;
let failed = 0;
for (const { name, fn } of tests) {
try {
await fn();
console.log(` ${name}... ✓`);
passed++;
} catch (error) {
const msg = error instanceof Error ? error.message : String(error);
console.log(` ${name}... ✗`);
console.log(` Error: ${msg}\n`);
failed++;
}
}
console.log(`\n${passed} passed, ${failed} failed`);
if (failed > 0) {
process.exit(1);
}
}
runTests();

View File

@@ -0,0 +1,602 @@
/**
* Integration tests for LSP - spawns real language servers and detects errors
*
* Run with: npm run test:integration
*
* Skips tests if language server is not installed.
*/
// Suppress stream errors from vscode-jsonrpc when LSP process exits
process.on('uncaughtException', (err) => {
if (err.message?.includes('write after end')) return;
console.error('Uncaught:', err);
process.exit(1);
});
import { mkdtemp, rm, writeFile, mkdir } from "fs/promises";
import { existsSync, statSync } from "fs";
import { tmpdir } from "os";
import { join, delimiter } from "path";
import { LSPManager } from "../lsp-core.js";
// ============================================================================
// Test utilities
// ============================================================================
const tests: Array<{ name: string; fn: () => Promise<void> }> = [];
let skipped = 0;
function test(name: string, fn: () => Promise<void>) {
tests.push({ name, fn });
}
function assert(condition: boolean, message: string) {
if (!condition) throw new Error(message);
}
class SkipTest extends Error {
constructor(reason: string) {
super(reason);
this.name = "SkipTest";
}
}
function skip(reason: string): never {
throw new SkipTest(reason);
}
// Search paths matching lsp-core.ts
const SEARCH_PATHS = [
...(process.env.PATH?.split(delimiter) || []),
"/usr/local/bin",
"/opt/homebrew/bin",
`${process.env.HOME || ""}/.pub-cache/bin`,
`${process.env.HOME || ""}/fvm/default/bin`,
`${process.env.HOME || ""}/go/bin`,
`${process.env.HOME || ""}/.cargo/bin`,
];
function commandExists(cmd: string): boolean {
for (const dir of SEARCH_PATHS) {
const full = join(dir, cmd);
try {
if (existsSync(full) && statSync(full).isFile()) return true;
} catch {}
}
return false;
}
// ============================================================================
// TypeScript
// ============================================================================
test("typescript: detects type errors", async () => {
if (!commandExists("typescript-language-server")) {
skip("typescript-language-server not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-ts-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "package.json"), "{}");
await writeFile(join(dir, "tsconfig.json"), JSON.stringify({
compilerOptions: { strict: true, noEmit: true }
}));
// Code with type error
const file = join(dir, "index.ts");
await writeFile(file, `const x: string = 123;`);
const { diagnostics } = await manager.touchFileAndWait(file, 10000);
assert(diagnostics.length > 0, `Expected errors, got ${diagnostics.length}`);
assert(
diagnostics.some(d => d.message.toLowerCase().includes("type") || d.severity === 1),
`Expected type error, got: ${diagnostics.map(d => d.message).join(", ")}`
);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
test("typescript: valid code has no errors", async () => {
if (!commandExists("typescript-language-server")) {
skip("typescript-language-server not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-ts-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "package.json"), "{}");
await writeFile(join(dir, "tsconfig.json"), JSON.stringify({
compilerOptions: { strict: true, noEmit: true }
}));
const file = join(dir, "index.ts");
await writeFile(file, `const x: string = "hello";`);
const { diagnostics } = await manager.touchFileAndWait(file, 10000);
const errors = diagnostics.filter(d => d.severity === 1);
assert(errors.length === 0, `Expected no errors, got: ${errors.map(d => d.message).join(", ")}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
// ============================================================================
// Dart
// ============================================================================
test("dart: detects type errors", async () => {
if (!commandExists("dart")) {
skip("dart not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-dart-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "pubspec.yaml"), "name: test_app\nenvironment:\n sdk: ^3.0.0");
await mkdir(join(dir, "lib"));
const file = join(dir, "lib/main.dart");
// Type error: assigning int to String
await writeFile(file, `
void main() {
String x = 123;
print(x);
}
`);
const { diagnostics } = await manager.touchFileAndWait(file, 15000);
assert(diagnostics.length > 0, `Expected errors, got ${diagnostics.length}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
test("dart: valid code has no errors", async () => {
if (!commandExists("dart")) {
skip("dart not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-dart-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "pubspec.yaml"), "name: test_app\nenvironment:\n sdk: ^3.0.0");
await mkdir(join(dir, "lib"));
const file = join(dir, "lib/main.dart");
await writeFile(file, `
void main() {
String x = "hello";
print(x);
}
`);
const { diagnostics } = await manager.touchFileAndWait(file, 15000);
const errors = diagnostics.filter(d => d.severity === 1);
assert(errors.length === 0, `Expected no errors, got: ${errors.map(d => d.message).join(", ")}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
// ============================================================================
// Rust
// ============================================================================
test("rust: detects type errors", async () => {
if (!commandExists("rust-analyzer")) {
skip("rust-analyzer not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-rust-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "Cargo.toml"), `[package]\nname = "test"\nversion = "0.1.0"\nedition = "2021"`);
await mkdir(join(dir, "src"));
const file = join(dir, "src/main.rs");
await writeFile(file, `fn main() {\n let x: i32 = "hello";\n}`);
// rust-analyzer needs a LOT of time to initialize (compiles the project)
const { diagnostics } = await manager.touchFileAndWait(file, 60000);
assert(diagnostics.length > 0, `Expected errors, got ${diagnostics.length}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
test("rust: valid code has no errors", async () => {
if (!commandExists("rust-analyzer")) {
skip("rust-analyzer not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-rust-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "Cargo.toml"), `[package]\nname = "test"\nversion = "0.1.0"\nedition = "2021"`);
await mkdir(join(dir, "src"));
const file = join(dir, "src/main.rs");
await writeFile(file, `fn main() {\n let x = "hello";\n println!("{}", x);\n}`);
const { diagnostics } = await manager.touchFileAndWait(file, 60000);
const errors = diagnostics.filter(d => d.severity === 1);
assert(errors.length === 0, `Expected no errors, got: ${errors.map(d => d.message).join(", ")}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
// ============================================================================
// Go
// ============================================================================
test("go: detects type errors", async () => {
if (!commandExists("gopls")) {
skip("gopls not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-go-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "go.mod"), "module test\n\ngo 1.21");
const file = join(dir, "main.go");
// Type error: cannot use int as string
await writeFile(file, `package main
func main() {
var x string = 123
println(x)
}
`);
const { diagnostics } = await manager.touchFileAndWait(file, 15000);
assert(diagnostics.length > 0, `Expected errors, got ${diagnostics.length}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
test("go: valid code has no errors", async () => {
if (!commandExists("gopls")) {
skip("gopls not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-go-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "go.mod"), "module test\n\ngo 1.21");
const file = join(dir, "main.go");
await writeFile(file, `package main
func main() {
var x string = "hello"
println(x)
}
`);
const { diagnostics } = await manager.touchFileAndWait(file, 15000);
const errors = diagnostics.filter(d => d.severity === 1);
assert(errors.length === 0, `Expected no errors, got: ${errors.map(d => d.message).join(", ")}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
// ============================================================================
// Kotlin
// ============================================================================
test("kotlin: detects syntax errors", async () => {
if (!commandExists("kotlin-language-server")) {
skip("kotlin-language-server not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-kt-"));
const manager = new LSPManager(dir);
try {
// Minimal Gradle markers so the LSP picks a root
await writeFile(join(dir, "settings.gradle.kts"), "rootProject.name = \"test\"\n");
await writeFile(join(dir, "build.gradle.kts"), "// empty\n");
await mkdir(join(dir, "src/main/kotlin"), { recursive: true });
const file = join(dir, "src/main/kotlin/Main.kt");
// Syntax error
await writeFile(file, "fun main() { val x = }\n");
const { diagnostics, receivedResponse } = await manager.touchFileAndWait(file, 30000);
assert(receivedResponse, "Expected Kotlin LSP to respond");
assert(diagnostics.length > 0, `Expected errors, got ${diagnostics.length}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
test("kotlin: valid code has no errors", async () => {
if (!commandExists("kotlin-language-server")) {
skip("kotlin-language-server not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-kt-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "settings.gradle.kts"), "rootProject.name = \"test\"\n");
await writeFile(join(dir, "build.gradle.kts"), "// empty\n");
await mkdir(join(dir, "src/main/kotlin"), { recursive: true });
const file = join(dir, "src/main/kotlin/Main.kt");
await writeFile(file, "fun main() { val x = 1; println(x) }\n");
const { diagnostics, receivedResponse } = await manager.touchFileAndWait(file, 30000);
assert(receivedResponse, "Expected Kotlin LSP to respond");
const errors = diagnostics.filter(d => d.severity === 1);
assert(errors.length === 0, `Expected no errors, got: ${errors.map(d => d.message).join(", ")}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
// ============================================================================
// Python
// ============================================================================
test("python: detects type errors", async () => {
if (!commandExists("pyright-langserver")) {
skip("pyright-langserver not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-py-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "pyproject.toml"), `[project]\nname = "test"`);
const file = join(dir, "main.py");
// Type error with type annotation
await writeFile(file, `
def greet(name: str) -> str:
return "Hello, " + name
x: str = 123 # Type error
result = greet(456) # Type error
`);
const { diagnostics } = await manager.touchFileAndWait(file, 10000);
assert(diagnostics.length > 0, `Expected errors, got ${diagnostics.length}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
test("python: valid code has no errors", async () => {
if (!commandExists("pyright-langserver")) {
skip("pyright-langserver not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-py-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "pyproject.toml"), `[project]\nname = "test"`);
const file = join(dir, "main.py");
await writeFile(file, `
def greet(name: str) -> str:
return "Hello, " + name
x: str = "world"
result = greet(x)
`);
const { diagnostics } = await manager.touchFileAndWait(file, 10000);
const errors = diagnostics.filter(d => d.severity === 1);
assert(errors.length === 0, `Expected no errors, got: ${errors.map(d => d.message).join(", ")}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
// ============================================================================
// Rename (TypeScript)
// ============================================================================
test("typescript: rename symbol", async () => {
if (!commandExists("typescript-language-server")) {
skip("typescript-language-server not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-ts-rename-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "package.json"), "{}");
await writeFile(join(dir, "tsconfig.json"), JSON.stringify({
compilerOptions: { strict: true, noEmit: true }
}));
const file = join(dir, "index.ts");
await writeFile(file, `function greet(name: string) {
return "Hello, " + name;
}
const result = greet("world");
`);
// Touch file first to ensure it's loaded
await manager.touchFileAndWait(file, 10000);
// Rename 'greet' at line 1, col 10
const edit = await manager.rename(file, 1, 10, "sayHello");
if (!edit) throw new Error("Expected rename to return WorkspaceEdit");
assert(
edit.changes !== undefined || edit.documentChanges !== undefined,
"Expected changes or documentChanges in WorkspaceEdit"
);
// Should have edits for both the function definition and the call
const allEdits: any[] = [];
if (edit.changes) {
for (const edits of Object.values(edit.changes)) {
allEdits.push(...(edits as any[]));
}
}
if (edit.documentChanges) {
for (const change of edit.documentChanges as any[]) {
if (change.edits) allEdits.push(...change.edits);
}
}
assert(allEdits.length >= 2, `Expected at least 2 edits (definition + usage), got ${allEdits.length}`);
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
// ============================================================================
// Code Actions (TypeScript)
// ============================================================================
test("typescript: get code actions for error", async () => {
if (!commandExists("typescript-language-server")) {
skip("typescript-language-server not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-ts-actions-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "package.json"), "{}");
await writeFile(join(dir, "tsconfig.json"), JSON.stringify({
compilerOptions: { strict: true, noEmit: true }
}));
const file = join(dir, "index.ts");
// Missing import - should offer "Add import" code action
await writeFile(file, `const x: Promise<string> = Promise.resolve("hello");
console.log(x);
`);
// Touch to get diagnostics first
await manager.touchFileAndWait(file, 10000);
// Get code actions at line 1
const actions = await manager.getCodeActions(file, 1, 1, 1, 50);
// May or may not have actions depending on the code, but shouldn't throw
assert(Array.isArray(actions), "Expected array of code actions");
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
test("typescript: code actions for missing function", async () => {
if (!commandExists("typescript-language-server")) {
skip("typescript-language-server not installed");
}
const dir = await mkdtemp(join(tmpdir(), "lsp-ts-actions2-"));
const manager = new LSPManager(dir);
try {
await writeFile(join(dir, "package.json"), "{}");
await writeFile(join(dir, "tsconfig.json"), JSON.stringify({
compilerOptions: { strict: true, noEmit: true }
}));
const file = join(dir, "index.ts");
// Call undefined function - should offer quick fix
await writeFile(file, `const result = undefinedFunction();
`);
await manager.touchFileAndWait(file, 10000);
// Get code actions where the error is
const actions = await manager.getCodeActions(file, 1, 16, 1, 33);
// TypeScript should offer to create the function
assert(Array.isArray(actions), "Expected array of code actions");
// Note: we don't assert on action count since it depends on TS version
} finally {
await manager.shutdown();
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
});
// ============================================================================
// Run tests
// ============================================================================
async function runTests(): Promise<void> {
console.log("Running LSP integration tests...\n");
console.log("Note: Tests are skipped if language server is not installed.\n");
let passed = 0;
let failed = 0;
for (const { name, fn } of tests) {
try {
await fn();
console.log(` ${name}... ✓`);
passed++;
} catch (error) {
if (error instanceof SkipTest) {
console.log(` ${name}... ⊘ (${error.message})`);
skipped++;
} else {
const msg = error instanceof Error ? error.message : String(error);
console.log(` ${name}... ✗`);
console.log(` Error: ${msg}\n`);
failed++;
}
}
}
console.log(`\n${passed} passed, ${failed} failed, ${skipped} skipped`);
if (failed > 0) {
process.exit(1);
}
}
runTests();

View File

@@ -0,0 +1,898 @@
/**
* Tests for LSP hook - configuration and utility functions
*
* Run with: npm test
*
* These tests cover:
* - Project root detection for various languages
* - Language ID mappings
* - URI construction
* - Server configuration correctness
*/
import { mkdtemp, rm, writeFile, mkdir } from "fs/promises";
import { tmpdir } from "os";
import { join } from "path";
import { pathToFileURL } from "url";
import { LSP_SERVERS, LANGUAGE_IDS } from "../lsp-core.js";
// ============================================================================
// Test utilities
// ============================================================================
interface TestResult {
name: string;
passed: boolean;
error?: string;
}
const tests: Array<{ name: string; fn: () => Promise<void> }> = [];
function test(name: string, fn: () => Promise<void>) {
tests.push({ name, fn });
}
function assert(condition: boolean, message: string) {
if (!condition) throw new Error(message);
}
function assertEquals<T>(actual: T, expected: T, message: string) {
assert(
actual === expected,
`${message}\nExpected: ${JSON.stringify(expected)}\nActual: ${JSON.stringify(actual)}`
);
}
function assertIncludes(arr: string[], item: string, message: string) {
assert(arr.includes(item), `${message}\nArray: [${arr.join(", ")}]\nMissing: ${item}`);
}
/** Create a temp directory with optional file structure */
async function withTempDir(
structure: Record<string, string | null>, // null = directory, string = file content
fn: (dir: string) => Promise<void>
): Promise<void> {
const dir = await mkdtemp(join(tmpdir(), "lsp-test-"));
try {
for (const [path, content] of Object.entries(structure)) {
const fullPath = join(dir, path);
if (content === null) {
await mkdir(fullPath, { recursive: true });
} else {
await mkdir(join(dir, path.split("/").slice(0, -1).join("/")), { recursive: true }).catch(() => {});
await writeFile(fullPath, content);
}
}
await fn(dir);
} finally {
await rm(dir, { recursive: true, force: true }).catch(() => {});
}
}
// ============================================================================
// Language ID tests
// ============================================================================
test("LANGUAGE_IDS: TypeScript extensions", async () => {
assertEquals(LANGUAGE_IDS[".ts"], "typescript", ".ts should map to typescript");
assertEquals(LANGUAGE_IDS[".tsx"], "typescriptreact", ".tsx should map to typescriptreact");
assertEquals(LANGUAGE_IDS[".mts"], "typescript", ".mts should map to typescript");
assertEquals(LANGUAGE_IDS[".cts"], "typescript", ".cts should map to typescript");
});
test("LANGUAGE_IDS: JavaScript extensions", async () => {
assertEquals(LANGUAGE_IDS[".js"], "javascript", ".js should map to javascript");
assertEquals(LANGUAGE_IDS[".jsx"], "javascriptreact", ".jsx should map to javascriptreact");
assertEquals(LANGUAGE_IDS[".mjs"], "javascript", ".mjs should map to javascript");
assertEquals(LANGUAGE_IDS[".cjs"], "javascript", ".cjs should map to javascript");
});
test("LANGUAGE_IDS: Dart extension", async () => {
assertEquals(LANGUAGE_IDS[".dart"], "dart", ".dart should map to dart");
});
test("LANGUAGE_IDS: Go extension", async () => {
assertEquals(LANGUAGE_IDS[".go"], "go", ".go should map to go");
});
test("LANGUAGE_IDS: Rust extension", async () => {
assertEquals(LANGUAGE_IDS[".rs"], "rust", ".rs should map to rust");
});
test("LANGUAGE_IDS: Kotlin extensions", async () => {
assertEquals(LANGUAGE_IDS[".kt"], "kotlin", ".kt should map to kotlin");
assertEquals(LANGUAGE_IDS[".kts"], "kotlin", ".kts should map to kotlin");
});
test("LANGUAGE_IDS: Swift extension", async () => {
assertEquals(LANGUAGE_IDS[".swift"], "swift", ".swift should map to swift");
});
test("LANGUAGE_IDS: Python extensions", async () => {
assertEquals(LANGUAGE_IDS[".py"], "python", ".py should map to python");
assertEquals(LANGUAGE_IDS[".pyi"], "python", ".pyi should map to python");
});
test("LANGUAGE_IDS: Vue/Svelte/Astro extensions", async () => {
assertEquals(LANGUAGE_IDS[".vue"], "vue", ".vue should map to vue");
assertEquals(LANGUAGE_IDS[".svelte"], "svelte", ".svelte should map to svelte");
assertEquals(LANGUAGE_IDS[".astro"], "astro", ".astro should map to astro");
});
// ============================================================================
// Server configuration tests
// ============================================================================
test("LSP_SERVERS: has TypeScript server", async () => {
const server = LSP_SERVERS.find(s => s.id === "typescript");
assert(server !== undefined, "Should have typescript server");
assertIncludes(server!.extensions, ".ts", "Should handle .ts");
assertIncludes(server!.extensions, ".tsx", "Should handle .tsx");
assertIncludes(server!.extensions, ".js", "Should handle .js");
assertIncludes(server!.extensions, ".jsx", "Should handle .jsx");
});
test("LSP_SERVERS: has Dart server", async () => {
const server = LSP_SERVERS.find(s => s.id === "dart");
assert(server !== undefined, "Should have dart server");
assertIncludes(server!.extensions, ".dart", "Should handle .dart");
});
test("LSP_SERVERS: has Rust Analyzer server", async () => {
const server = LSP_SERVERS.find(s => s.id === "rust-analyzer");
assert(server !== undefined, "Should have rust-analyzer server");
assertIncludes(server!.extensions, ".rs", "Should handle .rs");
});
test("LSP_SERVERS: has Gopls server", async () => {
const server = LSP_SERVERS.find(s => s.id === "gopls");
assert(server !== undefined, "Should have gopls server");
assertIncludes(server!.extensions, ".go", "Should handle .go");
});
test("LSP_SERVERS: has Kotlin server", async () => {
const server = LSP_SERVERS.find(s => s.id === "kotlin");
assert(server !== undefined, "Should have kotlin server");
assertIncludes(server!.extensions, ".kt", "Should handle .kt");
assertIncludes(server!.extensions, ".kts", "Should handle .kts");
});
test("LSP_SERVERS: has Swift server", async () => {
const server = LSP_SERVERS.find(s => s.id === "swift");
assert(server !== undefined, "Should have swift server");
assertIncludes(server!.extensions, ".swift", "Should handle .swift");
});
test("LSP_SERVERS: has Pyright server", async () => {
const server = LSP_SERVERS.find(s => s.id === "pyright");
assert(server !== undefined, "Should have pyright server");
assertIncludes(server!.extensions, ".py", "Should handle .py");
assertIncludes(server!.extensions, ".pyi", "Should handle .pyi");
});
// ============================================================================
// TypeScript root detection tests
// ============================================================================
test("typescript: finds root with package.json", async () => {
await withTempDir({
"package.json": "{}",
"src/index.ts": "export const x = 1;",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const root = server.findRoot(join(dir, "src/index.ts"), dir);
assertEquals(root, dir, "Should find root at package.json location");
});
});
test("typescript: finds root with tsconfig.json", async () => {
await withTempDir({
"tsconfig.json": "{}",
"src/index.ts": "export const x = 1;",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const root = server.findRoot(join(dir, "src/index.ts"), dir);
assertEquals(root, dir, "Should find root at tsconfig.json location");
});
});
test("typescript: finds root with jsconfig.json", async () => {
await withTempDir({
"jsconfig.json": "{}",
"src/app.js": "const x = 1;",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const root = server.findRoot(join(dir, "src/app.js"), dir);
assertEquals(root, dir, "Should find root at jsconfig.json location");
});
});
test("typescript: returns undefined for deno projects", async () => {
await withTempDir({
"deno.json": "{}",
"main.ts": "console.log('deno');",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const root = server.findRoot(join(dir, "main.ts"), dir);
assertEquals(root, undefined, "Should return undefined for deno projects");
});
});
test("typescript: nested package finds nearest root", async () => {
await withTempDir({
"package.json": "{}",
"packages/web/package.json": "{}",
"packages/web/src/index.ts": "export const x = 1;",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const root = server.findRoot(join(dir, "packages/web/src/index.ts"), dir);
assertEquals(root, join(dir, "packages/web"), "Should find nearest package.json");
});
});
// ============================================================================
// Dart root detection tests
// ============================================================================
test("dart: finds root with pubspec.yaml", async () => {
await withTempDir({
"pubspec.yaml": "name: my_app",
"lib/main.dart": "void main() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "dart")!;
const root = server.findRoot(join(dir, "lib/main.dart"), dir);
assertEquals(root, dir, "Should find root at pubspec.yaml location");
});
});
test("dart: finds root with analysis_options.yaml", async () => {
await withTempDir({
"analysis_options.yaml": "linter: rules:",
"lib/main.dart": "void main() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "dart")!;
const root = server.findRoot(join(dir, "lib/main.dart"), dir);
assertEquals(root, dir, "Should find root at analysis_options.yaml location");
});
});
test("dart: nested package finds nearest root", async () => {
await withTempDir({
"pubspec.yaml": "name: monorepo",
"packages/core/pubspec.yaml": "name: core",
"packages/core/lib/core.dart": "void init() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "dart")!;
const root = server.findRoot(join(dir, "packages/core/lib/core.dart"), dir);
assertEquals(root, join(dir, "packages/core"), "Should find nearest pubspec.yaml");
});
});
// ============================================================================
// Rust root detection tests
// ============================================================================
test("rust: finds root with Cargo.toml", async () => {
await withTempDir({
"Cargo.toml": "[package]\nname = \"my_crate\"",
"src/lib.rs": "pub fn hello() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "rust-analyzer")!;
const root = server.findRoot(join(dir, "src/lib.rs"), dir);
assertEquals(root, dir, "Should find root at Cargo.toml location");
});
});
test("rust: nested workspace member finds nearest Cargo.toml", async () => {
await withTempDir({
"Cargo.toml": "[workspace]\nmembers = [\"crates/*\"]",
"crates/core/Cargo.toml": "[package]\nname = \"core\"",
"crates/core/src/lib.rs": "pub fn init() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "rust-analyzer")!;
const root = server.findRoot(join(dir, "crates/core/src/lib.rs"), dir);
assertEquals(root, join(dir, "crates/core"), "Should find nearest Cargo.toml");
});
});
// ============================================================================
// Go root detection tests (including gopls bug fix verification)
// ============================================================================
test("gopls: finds root with go.mod", async () => {
await withTempDir({
"go.mod": "module example.com/myapp",
"main.go": "package main",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "gopls")!;
const root = server.findRoot(join(dir, "main.go"), dir);
assertEquals(root, dir, "Should find root at go.mod location");
});
});
test("gopls: finds root with go.work (workspace)", async () => {
await withTempDir({
"go.work": "go 1.21\nuse ./app",
"app/go.mod": "module example.com/app",
"app/main.go": "package main",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "gopls")!;
const root = server.findRoot(join(dir, "app/main.go"), dir);
assertEquals(root, dir, "Should find root at go.work location (workspace root)");
});
});
test("gopls: prefers go.work over go.mod", async () => {
await withTempDir({
"go.work": "go 1.21\nuse ./app",
"go.mod": "module example.com/root",
"app/go.mod": "module example.com/app",
"app/main.go": "package main",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "gopls")!;
const root = server.findRoot(join(dir, "app/main.go"), dir);
// go.work is found first, so it should return the go.work location
assertEquals(root, dir, "Should prefer go.work over go.mod");
});
});
test("gopls: returns undefined when no go.mod or go.work (bug fix verification)", async () => {
await withTempDir({
"main.go": "package main",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "gopls")!;
const root = server.findRoot(join(dir, "main.go"), dir);
// This test verifies the bug fix: previously this would return undefined
// because `undefined !== cwd` was true, skipping the go.mod check
assertEquals(root, undefined, "Should return undefined when no go.mod or go.work");
});
});
test("gopls: finds go.mod when go.work not present (bug fix verification)", async () => {
await withTempDir({
"go.mod": "module example.com/myapp",
"cmd/server/main.go": "package main",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "gopls")!;
const root = server.findRoot(join(dir, "cmd/server/main.go"), dir);
// This is the key test for the bug fix
// Previously: findRoot(go.work) returns undefined, then `undefined !== cwd` is true,
// so it would return undefined without checking go.mod
// After fix: if go.work not found, falls through to check go.mod
assertEquals(root, dir, "Should find go.mod when go.work is not present");
});
});
// ============================================================================
// Kotlin root detection tests
// ============================================================================
test("kotlin: finds root with settings.gradle.kts", async () => {
await withTempDir({
"settings.gradle.kts": "rootProject.name = \"myapp\"",
"app/src/main/kotlin/Main.kt": "fun main() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "kotlin")!;
const root = server.findRoot(join(dir, "app/src/main/kotlin/Main.kt"), dir);
assertEquals(root, dir, "Should find root at settings.gradle.kts location");
});
});
test("kotlin: prefers settings.gradle(.kts) over nested build.gradle", async () => {
await withTempDir({
"settings.gradle": "rootProject.name = 'root'",
"app/build.gradle": "plugins {}",
"app/src/main/kotlin/Main.kt": "fun main() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "kotlin")!;
const root = server.findRoot(join(dir, "app/src/main/kotlin/Main.kt"), dir);
assertEquals(root, dir, "Should prefer settings.gradle at workspace root");
});
});
test("kotlin: finds root with pom.xml", async () => {
await withTempDir({
"pom.xml": "<project></project>",
"src/main/kotlin/Main.kt": "fun main() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "kotlin")!;
const root = server.findRoot(join(dir, "src/main/kotlin/Main.kt"), dir);
assertEquals(root, dir, "Should find root at pom.xml location");
});
});
// ============================================================================
// Swift root detection tests
// ============================================================================
test("swift: finds root with Package.swift", async () => {
await withTempDir({
"Package.swift": "// swift-tools-version: 5.9",
"Sources/App/main.swift": "print(\"hi\")",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "swift")!;
const root = server.findRoot(join(dir, "Sources/App/main.swift"), dir);
assertEquals(root, dir, "Should find root at Package.swift location");
});
});
test("swift: finds root with Xcode project", async () => {
await withTempDir({
"MyApp.xcodeproj/project.pbxproj": "// pbxproj",
"MyApp/main.swift": "print(\"hi\")",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "swift")!;
const root = server.findRoot(join(dir, "MyApp/main.swift"), dir);
assertEquals(root, dir, "Should find root at Xcode project location");
});
});
test("swift: finds root with Xcode workspace", async () => {
await withTempDir({
"MyApp.xcworkspace/contents.xcworkspacedata": "<Workspace/>",
"MyApp/main.swift": "print(\"hi\")",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "swift")!;
const root = server.findRoot(join(dir, "MyApp/main.swift"), dir);
assertEquals(root, dir, "Should find root at Xcode workspace location");
});
});
// ============================================================================
// Python root detection tests
// ============================================================================
test("pyright: finds root with pyproject.toml", async () => {
await withTempDir({
"pyproject.toml": "[project]\nname = \"myapp\"",
"src/main.py": "print('hello')",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "pyright")!;
const root = server.findRoot(join(dir, "src/main.py"), dir);
assertEquals(root, dir, "Should find root at pyproject.toml location");
});
});
test("pyright: finds root with setup.py", async () => {
await withTempDir({
"setup.py": "from setuptools import setup",
"myapp/main.py": "print('hello')",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "pyright")!;
const root = server.findRoot(join(dir, "myapp/main.py"), dir);
assertEquals(root, dir, "Should find root at setup.py location");
});
});
test("pyright: finds root with requirements.txt", async () => {
await withTempDir({
"requirements.txt": "flask>=2.0",
"app.py": "from flask import Flask",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "pyright")!;
const root = server.findRoot(join(dir, "app.py"), dir);
assertEquals(root, dir, "Should find root at requirements.txt location");
});
});
// ============================================================================
// URI construction tests (pathToFileURL)
// ============================================================================
test("pathToFileURL: handles simple paths", async () => {
const uri = pathToFileURL("/home/user/project/file.ts").href;
assertEquals(uri, "file:///home/user/project/file.ts", "Should create proper file URI");
});
test("pathToFileURL: encodes special characters", async () => {
const uri = pathToFileURL("/home/user/my project/file.ts").href;
assert(uri.includes("my%20project"), "Should URL-encode spaces");
});
test("pathToFileURL: handles unicode", async () => {
const uri = pathToFileURL("/home/user/项目/file.ts").href;
// pathToFileURL properly encodes unicode
assert(uri.startsWith("file:///"), "Should start with file:///");
assert(uri.includes("file.ts"), "Should contain filename");
});
// ============================================================================
// Vue/Svelte root detection tests
// ============================================================================
test("vue: finds root with package.json", async () => {
await withTempDir({
"package.json": "{}",
"src/App.vue": "<template></template>",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "vue")!;
const root = server.findRoot(join(dir, "src/App.vue"), dir);
assertEquals(root, dir, "Should find root at package.json location");
});
});
test("vue: finds root with vite.config.ts", async () => {
await withTempDir({
"vite.config.ts": "export default {}",
"src/App.vue": "<template></template>",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "vue")!;
const root = server.findRoot(join(dir, "src/App.vue"), dir);
assertEquals(root, dir, "Should find root at vite.config.ts location");
});
});
test("svelte: finds root with svelte.config.js", async () => {
await withTempDir({
"svelte.config.js": "export default {}",
"src/App.svelte": "<script></script>",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "svelte")!;
const root = server.findRoot(join(dir, "src/App.svelte"), dir);
assertEquals(root, dir, "Should find root at svelte.config.js location");
});
});
// ============================================================================
// Additional Rust tests (parity with TypeScript)
// ============================================================================
test("rust: finds root in src subdirectory", async () => {
await withTempDir({
"Cargo.toml": "[package]\nname = \"myapp\"",
"src/main.rs": "fn main() {}",
"src/lib.rs": "pub mod utils;",
"src/utils/mod.rs": "pub fn helper() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "rust-analyzer")!;
const root = server.findRoot(join(dir, "src/utils/mod.rs"), dir);
assertEquals(root, dir, "Should find root from deeply nested src file");
});
});
test("rust: workspace with multiple crates", async () => {
await withTempDir({
"Cargo.toml": "[workspace]\nmembers = [\"crates/*\"]",
"crates/api/Cargo.toml": "[package]\nname = \"api\"",
"crates/api/src/lib.rs": "pub fn serve() {}",
"crates/core/Cargo.toml": "[package]\nname = \"core\"",
"crates/core/src/lib.rs": "pub fn init() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "rust-analyzer")!;
// Each crate should find its own Cargo.toml
const apiRoot = server.findRoot(join(dir, "crates/api/src/lib.rs"), dir);
const coreRoot = server.findRoot(join(dir, "crates/core/src/lib.rs"), dir);
assertEquals(apiRoot, join(dir, "crates/api"), "API crate should find its Cargo.toml");
assertEquals(coreRoot, join(dir, "crates/core"), "Core crate should find its Cargo.toml");
});
});
test("rust: returns undefined when no Cargo.toml", async () => {
await withTempDir({
"main.rs": "fn main() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "rust-analyzer")!;
const root = server.findRoot(join(dir, "main.rs"), dir);
assertEquals(root, undefined, "Should return undefined when no Cargo.toml");
});
});
// ============================================================================
// Additional Dart tests (parity with TypeScript)
// ============================================================================
test("dart: Flutter project with pubspec.yaml", async () => {
await withTempDir({
"pubspec.yaml": "name: my_flutter_app\ndependencies:\n flutter:\n sdk: flutter",
"lib/main.dart": "import 'package:flutter/material.dart';",
"lib/screens/home.dart": "class HomeScreen {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "dart")!;
const root = server.findRoot(join(dir, "lib/screens/home.dart"), dir);
assertEquals(root, dir, "Should find root for Flutter project");
});
});
test("dart: returns undefined when no marker files", async () => {
await withTempDir({
"main.dart": "void main() {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "dart")!;
const root = server.findRoot(join(dir, "main.dart"), dir);
assertEquals(root, undefined, "Should return undefined when no pubspec.yaml or analysis_options.yaml");
});
});
test("dart: monorepo with multiple packages", async () => {
await withTempDir({
"pubspec.yaml": "name: monorepo",
"packages/auth/pubspec.yaml": "name: auth",
"packages/auth/lib/auth.dart": "class Auth {}",
"packages/ui/pubspec.yaml": "name: ui",
"packages/ui/lib/widgets.dart": "class Button {}",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "dart")!;
const authRoot = server.findRoot(join(dir, "packages/auth/lib/auth.dart"), dir);
const uiRoot = server.findRoot(join(dir, "packages/ui/lib/widgets.dart"), dir);
assertEquals(authRoot, join(dir, "packages/auth"), "Auth package should find its pubspec");
assertEquals(uiRoot, join(dir, "packages/ui"), "UI package should find its pubspec");
});
});
// ============================================================================
// Additional Python tests (parity with TypeScript)
// ============================================================================
test("pyright: finds root with pyrightconfig.json", async () => {
await withTempDir({
"pyrightconfig.json": "{}",
"src/app.py": "print('hello')",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "pyright")!;
const root = server.findRoot(join(dir, "src/app.py"), dir);
assertEquals(root, dir, "Should find root at pyrightconfig.json location");
});
});
test("pyright: returns undefined when no marker files", async () => {
await withTempDir({
"script.py": "print('hello')",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "pyright")!;
const root = server.findRoot(join(dir, "script.py"), dir);
assertEquals(root, undefined, "Should return undefined when no Python project markers");
});
});
test("pyright: monorepo with multiple packages", async () => {
await withTempDir({
"pyproject.toml": "[project]\nname = \"monorepo\"",
"packages/api/pyproject.toml": "[project]\nname = \"api\"",
"packages/api/src/main.py": "from flask import Flask",
"packages/worker/pyproject.toml": "[project]\nname = \"worker\"",
"packages/worker/src/tasks.py": "def process(): pass",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "pyright")!;
const apiRoot = server.findRoot(join(dir, "packages/api/src/main.py"), dir);
const workerRoot = server.findRoot(join(dir, "packages/worker/src/tasks.py"), dir);
assertEquals(apiRoot, join(dir, "packages/api"), "API package should find its pyproject.toml");
assertEquals(workerRoot, join(dir, "packages/worker"), "Worker package should find its pyproject.toml");
});
});
// ============================================================================
// Additional Go tests
// ============================================================================
test("gopls: monorepo with multiple modules", async () => {
await withTempDir({
"go.work": "go 1.21\nuse (\n ./api\n ./worker\n)",
"api/go.mod": "module example.com/api",
"api/main.go": "package main",
"worker/go.mod": "module example.com/worker",
"worker/main.go": "package main",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "gopls")!;
// With go.work present, all files should use workspace root
const apiRoot = server.findRoot(join(dir, "api/main.go"), dir);
const workerRoot = server.findRoot(join(dir, "worker/main.go"), dir);
assertEquals(apiRoot, dir, "API module should use go.work root");
assertEquals(workerRoot, dir, "Worker module should use go.work root");
});
});
test("gopls: nested cmd directory", async () => {
await withTempDir({
"go.mod": "module example.com/myapp",
"cmd/server/main.go": "package main",
"cmd/cli/main.go": "package main",
"internal/db/db.go": "package db",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "gopls")!;
const serverRoot = server.findRoot(join(dir, "cmd/server/main.go"), dir);
const cliRoot = server.findRoot(join(dir, "cmd/cli/main.go"), dir);
const dbRoot = server.findRoot(join(dir, "internal/db/db.go"), dir);
assertEquals(serverRoot, dir, "cmd/server should find go.mod at root");
assertEquals(cliRoot, dir, "cmd/cli should find go.mod at root");
assertEquals(dbRoot, dir, "internal/db should find go.mod at root");
});
});
// ============================================================================
// Additional TypeScript tests
// ============================================================================
test("typescript: pnpm workspace", async () => {
await withTempDir({
"package.json": "{}",
"pnpm-workspace.yaml": "packages:\n - packages/*",
"packages/web/package.json": "{}",
"packages/web/src/App.tsx": "export const App = () => null;",
"packages/api/package.json": "{}",
"packages/api/src/index.ts": "export const handler = () => {};",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const webRoot = server.findRoot(join(dir, "packages/web/src/App.tsx"), dir);
const apiRoot = server.findRoot(join(dir, "packages/api/src/index.ts"), dir);
assertEquals(webRoot, join(dir, "packages/web"), "Web package should find its package.json");
assertEquals(apiRoot, join(dir, "packages/api"), "API package should find its package.json");
});
});
test("typescript: returns undefined when no config files", async () => {
await withTempDir({
"script.ts": "const x = 1;",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const root = server.findRoot(join(dir, "script.ts"), dir);
assertEquals(root, undefined, "Should return undefined when no package.json or tsconfig.json");
});
});
test("typescript: prefers nearest tsconfig over package.json", async () => {
await withTempDir({
"package.json": "{}",
"apps/web/tsconfig.json": "{}",
"apps/web/src/index.ts": "export const x = 1;",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const root = server.findRoot(join(dir, "apps/web/src/index.ts"), dir);
// Should find tsconfig.json first (it's nearer than root package.json)
assertEquals(root, join(dir, "apps/web"), "Should find nearest config file");
});
});
// ============================================================================
// Additional Vue/Svelte tests
// ============================================================================
test("vue: Nuxt project", async () => {
await withTempDir({
"package.json": "{}",
"nuxt.config.ts": "export default {}",
"pages/index.vue": "<template></template>",
"components/Button.vue": "<template></template>",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "vue")!;
const pagesRoot = server.findRoot(join(dir, "pages/index.vue"), dir);
const componentsRoot = server.findRoot(join(dir, "components/Button.vue"), dir);
assertEquals(pagesRoot, dir, "Pages should find root");
assertEquals(componentsRoot, dir, "Components should find root");
});
});
test("vue: returns undefined when no config", async () => {
await withTempDir({
"App.vue": "<template></template>",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "vue")!;
const root = server.findRoot(join(dir, "App.vue"), dir);
assertEquals(root, undefined, "Should return undefined when no package.json or vite.config");
});
});
test("svelte: SvelteKit project", async () => {
await withTempDir({
"package.json": "{}",
"svelte.config.js": "export default {}",
"src/routes/+page.svelte": "<script></script>",
"src/lib/components/Button.svelte": "<script></script>",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "svelte")!;
const routeRoot = server.findRoot(join(dir, "src/routes/+page.svelte"), dir);
const libRoot = server.findRoot(join(dir, "src/lib/components/Button.svelte"), dir);
assertEquals(routeRoot, dir, "Route should find root");
assertEquals(libRoot, dir, "Lib component should find root");
});
});
test("svelte: returns undefined when no config", async () => {
await withTempDir({
"App.svelte": "<script></script>",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "svelte")!;
const root = server.findRoot(join(dir, "App.svelte"), dir);
assertEquals(root, undefined, "Should return undefined when no package.json or svelte.config.js");
});
});
// ============================================================================
// Stop boundary tests (findNearestFile respects cwd boundary)
// ============================================================================
test("stop boundary: does not search above cwd", async () => {
await withTempDir({
"package.json": "{}", // This is at root
"projects/myapp/src/index.ts": "export const x = 1;",
// Note: no package.json in projects/myapp
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
// When cwd is set to projects/myapp, it should NOT find the root package.json
const projectDir = join(dir, "projects/myapp");
const root = server.findRoot(join(projectDir, "src/index.ts"), projectDir);
assertEquals(root, undefined, "Should not find package.json above cwd boundary");
});
});
test("stop boundary: finds marker at cwd level", async () => {
await withTempDir({
"projects/myapp/package.json": "{}",
"projects/myapp/src/index.ts": "export const x = 1;",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const projectDir = join(dir, "projects/myapp");
const root = server.findRoot(join(projectDir, "src/index.ts"), projectDir);
assertEquals(root, projectDir, "Should find package.json at cwd level");
});
});
// ============================================================================
// Edge cases
// ============================================================================
test("edge: deeply nested file finds correct root", async () => {
await withTempDir({
"package.json": "{}",
"src/components/ui/buttons/primary/Button.tsx": "export const Button = () => null;",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const root = server.findRoot(join(dir, "src/components/ui/buttons/primary/Button.tsx"), dir);
assertEquals(root, dir, "Should find root even for deeply nested files");
});
});
test("edge: file at root level finds root", async () => {
await withTempDir({
"package.json": "{}",
"index.ts": "console.log('root');",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const root = server.findRoot(join(dir, "index.ts"), dir);
assertEquals(root, dir, "Should find root for file at root level");
});
});
test("edge: no marker files returns undefined", async () => {
await withTempDir({
"random.ts": "const x = 1;",
}, async (dir) => {
const server = LSP_SERVERS.find(s => s.id === "typescript")!;
const root = server.findRoot(join(dir, "random.ts"), dir);
assertEquals(root, undefined, "Should return undefined when no marker files");
});
});
// ============================================================================
// Run tests
// ============================================================================
async function runTests(): Promise<void> {
console.log("Running LSP tests...\n");
const results: TestResult[] = [];
let passed = 0;
let failed = 0;
for (const { name, fn } of tests) {
try {
await fn();
results.push({ name, passed: true });
console.log(` ${name}... ✓`);
passed++;
} catch (error) {
const errorMsg = error instanceof Error ? error.message : String(error);
results.push({ name, passed: false, error: errorMsg });
console.log(` ${name}... ✗`);
console.log(` Error: ${errorMsg}\n`);
failed++;
}
}
console.log(`\n${passed} passed, ${failed} failed`);
if (failed > 0) {
process.exit(1);
}
}
runTests();

View File

@@ -0,0 +1,13 @@
{
"compilerOptions": {
"target": "ES2022",
"module": "NodeNext",
"moduleResolution": "NodeNext",
"esModuleInterop": true,
"strict": true,
"skipLibCheck": true,
"noEmit": true,
"lib": ["ES2022"]
},
"include": ["*.ts"]
}

View File

@@ -0,0 +1,65 @@
import { wrapTextWithAnsi } from "@mariozechner/pi-tui";
const INLINE_NOTE_SEPARATOR = " — note: ";
const INLINE_EDIT_CURSOR = "▍";
export const INLINE_NOTE_WRAP_PADDING = 2;
function sanitizeNoteForInlineDisplay(rawNote: string): string {
return rawNote.replace(/[\r\n\t]/g, " ").replace(/[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]/g, "");
}
function truncateTextKeepingTail(text: string, maxLength: number): string {
if (maxLength <= 0) return "";
if (text.length <= maxLength) return text;
if (maxLength === 1) return "…";
return `${text.slice(-(maxLength - 1))}`;
}
function truncateTextKeepingHead(text: string, maxLength: number): string {
if (maxLength <= 0) return "";
if (text.length <= maxLength) return text;
if (maxLength === 1) return "…";
return `${text.slice(0, maxLength - 1)}`;
}
export function buildOptionLabelWithInlineNote(
baseOptionLabel: string,
rawNote: string,
isEditingNote: boolean,
maxInlineLabelLength?: number,
): string {
const sanitizedNote = sanitizeNoteForInlineDisplay(rawNote);
if (!isEditingNote && sanitizedNote.trim().length === 0) {
return baseOptionLabel;
}
const labelPrefix = `${baseOptionLabel}${INLINE_NOTE_SEPARATOR}`;
const inlineNote = isEditingNote ? `${sanitizedNote}${INLINE_EDIT_CURSOR}` : sanitizedNote.trim();
const inlineLabel = `${labelPrefix}${inlineNote}`;
if (maxInlineLabelLength == null) {
return inlineLabel;
}
return isEditingNote
? truncateTextKeepingTail(inlineLabel, maxInlineLabelLength)
: truncateTextKeepingHead(inlineLabel, maxInlineLabelLength);
}
export function buildWrappedOptionLabelWithInlineNote(
baseOptionLabel: string,
rawNote: string,
isEditingNote: boolean,
maxInlineLabelLength: number,
wrapPadding = INLINE_NOTE_WRAP_PADDING,
): string[] {
const inlineLabel = buildOptionLabelWithInlineNote(baseOptionLabel, rawNote, isEditingNote);
const sanitizedWrapPadding = Number.isFinite(wrapPadding) ? Math.max(0, Math.floor(wrapPadding)) : 0;
const sanitizedMaxInlineLabelLength = Number.isFinite(maxInlineLabelLength)
? Math.max(1, Math.floor(maxInlineLabelLength))
: 1;
const wrapWidth = Math.max(1, sanitizedMaxInlineLabelLength - sanitizedWrapPadding);
const wrappedLines = wrapTextWithAnsi(inlineLabel, wrapWidth);
return wrappedLines.length > 0 ? wrappedLines : [""];
}

View File

@@ -0,0 +1,223 @@
import type { ExtensionUIContext } from "@mariozechner/pi-coding-agent";
import { Editor, type EditorTheme, Key, matchesKey, truncateToWidth, visibleWidth, wrapTextWithAnsi } from "@mariozechner/pi-tui";
import {
OTHER_OPTION,
appendRecommendedTagToOptionLabels,
buildSingleSelectionResult,
type AskOption,
type AskSelection,
} from "./ask-logic";
import { INLINE_NOTE_WRAP_PADDING, buildWrappedOptionLabelWithInlineNote } from "./ask-inline-note";
interface SingleQuestionInput {
question: string;
options: AskOption[];
recommended?: number;
}
interface InlineSelectionResult {
cancelled: boolean;
selectedOption?: string;
note?: string;
}
function resolveInitialCursorIndexFromRecommendedOption(
recommendedOptionIndex: number | undefined,
optionCount: number,
): number {
if (recommendedOptionIndex == null) return 0;
if (recommendedOptionIndex < 0 || recommendedOptionIndex >= optionCount) return 0;
return recommendedOptionIndex;
}
export async function askSingleQuestionWithInlineNote(
ui: ExtensionUIContext,
questionInput: SingleQuestionInput,
): Promise<AskSelection> {
const baseOptionLabels = questionInput.options.map((option) => option.label);
const optionLabelsWithRecommendedTag = appendRecommendedTagToOptionLabels(
baseOptionLabels,
questionInput.recommended,
);
const selectableOptionLabels = [...optionLabelsWithRecommendedTag, OTHER_OPTION];
const initialCursorIndex = resolveInitialCursorIndexFromRecommendedOption(
questionInput.recommended,
optionLabelsWithRecommendedTag.length,
);
const result = await ui.custom<InlineSelectionResult>((tui, theme, _keybindings, done) => {
let cursorOptionIndex = initialCursorIndex;
let isNoteEditorOpen = false;
let cachedRenderedLines: string[] | undefined;
const noteByOptionIndex = new Map<number, string>();
const editorTheme: EditorTheme = {
borderColor: (text) => theme.fg("accent", text),
selectList: {
selectedPrefix: (text) => theme.fg("accent", text),
selectedText: (text) => theme.fg("accent", text),
description: (text) => theme.fg("muted", text),
scrollInfo: (text) => theme.fg("dim", text),
noMatch: (text) => theme.fg("warning", text),
},
};
const noteEditor = new Editor(tui, editorTheme);
const requestUiRerender = () => {
cachedRenderedLines = undefined;
tui.requestRender();
};
const getRawNoteForOption = (optionIndex: number): string => noteByOptionIndex.get(optionIndex) ?? "";
const getTrimmedNoteForOption = (optionIndex: number): string => getRawNoteForOption(optionIndex).trim();
const loadCurrentNoteIntoEditor = () => {
noteEditor.setText(getRawNoteForOption(cursorOptionIndex));
};
const saveCurrentNoteFromEditor = (value: string) => {
noteByOptionIndex.set(cursorOptionIndex, value);
};
const submitCurrentSelection = (selectedOptionLabel: string, note: string) => {
done({
cancelled: false,
selectedOption: selectedOptionLabel,
note,
});
};
noteEditor.onChange = (value) => {
saveCurrentNoteFromEditor(value);
requestUiRerender();
};
noteEditor.onSubmit = (value) => {
saveCurrentNoteFromEditor(value);
const selectedOptionLabel = selectableOptionLabels[cursorOptionIndex];
const trimmedNote = value.trim();
if (selectedOptionLabel === OTHER_OPTION && !trimmedNote) {
requestUiRerender();
return;
}
submitCurrentSelection(selectedOptionLabel, trimmedNote);
};
const render = (width: number): string[] => {
if (cachedRenderedLines) return cachedRenderedLines;
const renderedLines: string[] = [];
const addLine = (line: string) => renderedLines.push(truncateToWidth(line, width));
addLine(theme.fg("accent", "─".repeat(width)));
for (const questionLine of wrapTextWithAnsi(questionInput.question, Math.max(1, width - 1))) {
addLine(` ${theme.fg("text", questionLine)}`);
}
renderedLines.push("");
for (let optionIndex = 0; optionIndex < selectableOptionLabels.length; optionIndex++) {
const optionLabel = selectableOptionLabels[optionIndex];
const isCursorOption = optionIndex === cursorOptionIndex;
const isEditingThisOption = isNoteEditorOpen && isCursorOption;
const cursorPrefixText = isCursorOption ? "→ " : " ";
const cursorPrefix = isCursorOption ? theme.fg("accent", cursorPrefixText) : cursorPrefixText;
const bullet = isCursorOption ? "●" : "○";
const markerText = `${bullet} `;
const optionColor = isCursorOption ? "accent" : "text";
const prefixWidth = visibleWidth(cursorPrefixText) + visibleWidth(markerText);
const wrappedInlineLabelLines = buildWrappedOptionLabelWithInlineNote(
optionLabel,
getRawNoteForOption(optionIndex),
isEditingThisOption,
Math.max(1, width - prefixWidth),
INLINE_NOTE_WRAP_PADDING,
);
const continuationPrefix = " ".repeat(prefixWidth);
addLine(`${cursorPrefix}${theme.fg(optionColor, `${markerText}${wrappedInlineLabelLines[0] ?? ""}`)}`);
for (const wrappedLine of wrappedInlineLabelLines.slice(1)) {
addLine(`${continuationPrefix}${theme.fg(optionColor, wrappedLine)}`);
}
}
renderedLines.push("");
if (isNoteEditorOpen) {
addLine(theme.fg("dim", " Typing note inline • Enter submit • Tab/Esc stop editing"));
} else if (getTrimmedNoteForOption(cursorOptionIndex).length > 0) {
addLine(theme.fg("dim", " ↑↓ move • Enter submit • Tab edit note • Esc cancel"));
} else {
addLine(theme.fg("dim", " ↑↓ move • Enter submit • Tab add note • Esc cancel"));
}
addLine(theme.fg("accent", "─".repeat(width)));
cachedRenderedLines = renderedLines;
return renderedLines;
};
const handleInput = (data: string) => {
if (isNoteEditorOpen) {
if (matchesKey(data, Key.tab) || matchesKey(data, Key.escape)) {
isNoteEditorOpen = false;
requestUiRerender();
return;
}
noteEditor.handleInput(data);
requestUiRerender();
return;
}
if (matchesKey(data, Key.up)) {
cursorOptionIndex = Math.max(0, cursorOptionIndex - 1);
requestUiRerender();
return;
}
if (matchesKey(data, Key.down)) {
cursorOptionIndex = Math.min(selectableOptionLabels.length - 1, cursorOptionIndex + 1);
requestUiRerender();
return;
}
if (matchesKey(data, Key.tab)) {
isNoteEditorOpen = true;
loadCurrentNoteIntoEditor();
requestUiRerender();
return;
}
if (matchesKey(data, Key.enter)) {
const selectedOptionLabel = selectableOptionLabels[cursorOptionIndex];
const trimmedNote = getTrimmedNoteForOption(cursorOptionIndex);
if (selectedOptionLabel === OTHER_OPTION && !trimmedNote) {
isNoteEditorOpen = true;
loadCurrentNoteIntoEditor();
requestUiRerender();
return;
}
submitCurrentSelection(selectedOptionLabel, trimmedNote);
return;
}
if (matchesKey(data, Key.escape)) {
done({ cancelled: true });
}
};
return {
render,
invalidate: () => {
cachedRenderedLines = undefined;
},
handleInput,
};
});
if (result.cancelled || !result.selectedOption) {
return { selectedOptions: [] };
}
return buildSingleSelectionResult(result.selectedOption, result.note);
}

View File

@@ -0,0 +1,98 @@
export const OTHER_OPTION = "Other (type your own)";
const RECOMMENDED_OPTION_TAG = " (Recommended)";
export interface AskOption {
label: string;
}
export interface AskQuestion {
id: string;
question: string;
options: AskOption[];
multi?: boolean;
recommended?: number;
}
export interface AskSelection {
selectedOptions: string[];
customInput?: string;
}
export function appendRecommendedTagToOptionLabels(
optionLabels: string[],
recommendedOptionIndex?: number,
): string[] {
if (
recommendedOptionIndex == null ||
recommendedOptionIndex < 0 ||
recommendedOptionIndex >= optionLabels.length
) {
return optionLabels;
}
return optionLabels.map((optionLabel, optionIndex) => {
if (optionIndex !== recommendedOptionIndex) return optionLabel;
if (optionLabel.endsWith(RECOMMENDED_OPTION_TAG)) return optionLabel;
return `${optionLabel}${RECOMMENDED_OPTION_TAG}`;
});
}
function removeRecommendedTagFromOptionLabel(optionLabel: string): string {
if (!optionLabel.endsWith(RECOMMENDED_OPTION_TAG)) {
return optionLabel;
}
return optionLabel.slice(0, -RECOMMENDED_OPTION_TAG.length);
}
export function buildSingleSelectionResult(selectedOptionLabel: string, note?: string): AskSelection {
const normalizedSelectedOption = removeRecommendedTagFromOptionLabel(selectedOptionLabel);
const normalizedNote = note?.trim();
if (normalizedSelectedOption === OTHER_OPTION) {
if (normalizedNote) {
return { selectedOptions: [], customInput: normalizedNote };
}
return { selectedOptions: [] };
}
if (normalizedNote) {
return { selectedOptions: [`${normalizedSelectedOption} - ${normalizedNote}`] };
}
return { selectedOptions: [normalizedSelectedOption] };
}
export function buildMultiSelectionResult(
optionLabels: string[],
selectedOptionIndexes: number[],
optionNotes: string[],
otherOptionIndex: number,
): AskSelection {
const selectedOptionSet = new Set(selectedOptionIndexes);
const selectedOptions: string[] = [];
let customInput: string | undefined;
for (let optionIndex = 0; optionIndex < optionLabels.length; optionIndex++) {
if (!selectedOptionSet.has(optionIndex)) continue;
const optionLabel = removeRecommendedTagFromOptionLabel(optionLabels[optionIndex]);
const optionNote = optionNotes[optionIndex]?.trim();
if (optionIndex === otherOptionIndex) {
if (optionNote) customInput = optionNote;
continue;
}
if (optionNote) {
selectedOptions.push(`${optionLabel} - ${optionNote}`);
} else {
selectedOptions.push(optionLabel);
}
}
if (customInput) {
return { selectedOptions, customInput };
}
return { selectedOptions };
}

View File

@@ -0,0 +1,514 @@
import type { ExtensionUIContext } from "@mariozechner/pi-coding-agent";
import { Editor, type EditorTheme, Key, matchesKey, truncateToWidth, visibleWidth, wrapTextWithAnsi } from "@mariozechner/pi-tui";
import {
OTHER_OPTION,
appendRecommendedTagToOptionLabels,
buildMultiSelectionResult,
buildSingleSelectionResult,
type AskQuestion,
type AskSelection,
} from "./ask-logic";
import { INLINE_NOTE_WRAP_PADDING, buildWrappedOptionLabelWithInlineNote } from "./ask-inline-note";
interface PreparedQuestion {
id: string;
question: string;
options: string[];
tabLabel: string;
multi: boolean;
otherOptionIndex: number;
}
interface TabsUIState {
cancelled: boolean;
selectedOptionIndexesByQuestion: number[][];
noteByQuestionByOption: string[][];
}
export function formatSelectionForSubmitReview(selection: AskSelection, isMulti: boolean): string {
const hasSelectedOptions = selection.selectedOptions.length > 0;
const hasCustomInput = Boolean(selection.customInput);
if (hasSelectedOptions && hasCustomInput) {
const selectedPart = isMulti
? `[${selection.selectedOptions.join(", ")}]`
: selection.selectedOptions[0];
return `${selectedPart} + Other: ${selection.customInput}`;
}
if (hasCustomInput) {
return `Other: ${selection.customInput}`;
}
if (hasSelectedOptions) {
return isMulti ? `[${selection.selectedOptions.join(", ")}]` : selection.selectedOptions[0];
}
return "(not answered)";
}
function clampIndex(index: number | undefined, maxExclusive: number): number {
if (index == null || Number.isNaN(index) || maxExclusive <= 0) return 0;
if (index < 0) return 0;
if (index >= maxExclusive) return maxExclusive - 1;
return index;
}
function normalizeTabLabel(id: string, fallback: string): string {
const normalized = id.trim().replace(/[_-]+/g, " ");
return normalized.length > 0 ? normalized : fallback;
}
function buildSelectionForQuestion(
question: PreparedQuestion,
selectedOptionIndexes: number[],
noteByOptionIndex: string[],
): AskSelection {
if (selectedOptionIndexes.length === 0) {
return { selectedOptions: [] };
}
if (question.multi) {
return buildMultiSelectionResult(question.options, selectedOptionIndexes, noteByOptionIndex, question.otherOptionIndex);
}
const selectedOptionIndex = selectedOptionIndexes[0];
const selectedOptionLabel = question.options[selectedOptionIndex] ?? OTHER_OPTION;
const note = noteByOptionIndex[selectedOptionIndex] ?? "";
return buildSingleSelectionResult(selectedOptionLabel, note);
}
function isQuestionSelectionValid(
question: PreparedQuestion,
selectedOptionIndexes: number[],
noteByOptionIndex: string[],
): boolean {
if (selectedOptionIndexes.length === 0) return false;
if (!selectedOptionIndexes.includes(question.otherOptionIndex)) return true;
const otherNote = noteByOptionIndex[question.otherOptionIndex]?.trim() ?? "";
return otherNote.length > 0;
}
function createTabsUiStateSnapshot(
cancelled: boolean,
selectedOptionIndexesByQuestion: number[][],
noteByQuestionByOption: string[][],
): TabsUIState {
return {
cancelled,
selectedOptionIndexesByQuestion: selectedOptionIndexesByQuestion.map((indexes) => [...indexes]),
noteByQuestionByOption: noteByQuestionByOption.map((notes) => [...notes]),
};
}
function addIndexToSelection(selectedOptionIndexes: number[], optionIndex: number): number[] {
if (selectedOptionIndexes.includes(optionIndex)) return selectedOptionIndexes;
return [...selectedOptionIndexes, optionIndex].sort((a, b) => a - b);
}
function removeIndexFromSelection(selectedOptionIndexes: number[], optionIndex: number): number[] {
return selectedOptionIndexes.filter((index) => index !== optionIndex);
}
export async function askQuestionsWithTabs(
ui: ExtensionUIContext,
questions: AskQuestion[],
): Promise<{ cancelled: boolean; selections: AskSelection[] }> {
const preparedQuestions: PreparedQuestion[] = questions.map((question, questionIndex) => {
const baseOptionLabels = question.options.map((option) => option.label);
const optionLabels = [...appendRecommendedTagToOptionLabels(baseOptionLabels, question.recommended), OTHER_OPTION];
return {
id: question.id,
question: question.question,
options: optionLabels,
tabLabel: normalizeTabLabel(question.id, `Q${questionIndex + 1}`),
multi: question.multi === true,
otherOptionIndex: optionLabels.length - 1,
};
});
const initialCursorOptionIndexByQuestion = preparedQuestions.map((preparedQuestion, questionIndex) =>
clampIndex(questions[questionIndex].recommended, preparedQuestion.options.length),
);
const result = await ui.custom<TabsUIState>((tui, theme, _keybindings, done) => {
let activeTabIndex = 0;
let isNoteEditorOpen = false;
let cachedRenderedLines: string[] | undefined;
const cursorOptionIndexByQuestion = [...initialCursorOptionIndexByQuestion];
const selectedOptionIndexesByQuestion = preparedQuestions.map(() => [] as number[]);
const noteByQuestionByOption = preparedQuestions.map((preparedQuestion) =>
Array(preparedQuestion.options.length).fill("") as string[],
);
const editorTheme: EditorTheme = {
borderColor: (text) => theme.fg("accent", text),
selectList: {
selectedPrefix: (text) => theme.fg("accent", text),
selectedText: (text) => theme.fg("accent", text),
description: (text) => theme.fg("muted", text),
scrollInfo: (text) => theme.fg("dim", text),
noMatch: (text) => theme.fg("warning", text),
},
};
const noteEditor = new Editor(tui, editorTheme);
const submitTabIndex = preparedQuestions.length;
const requestUiRerender = () => {
cachedRenderedLines = undefined;
tui.requestRender();
};
const getActiveQuestionIndex = (): number | null => {
if (activeTabIndex >= preparedQuestions.length) return null;
return activeTabIndex;
};
const getQuestionNote = (questionIndex: number, optionIndex: number): string =>
noteByQuestionByOption[questionIndex]?.[optionIndex] ?? "";
const getTrimmedQuestionNote = (questionIndex: number, optionIndex: number): string =>
getQuestionNote(questionIndex, optionIndex).trim();
const isAllQuestionSelectionsValid = (): boolean =>
preparedQuestions.every((preparedQuestion, questionIndex) =>
isQuestionSelectionValid(
preparedQuestion,
selectedOptionIndexesByQuestion[questionIndex],
noteByQuestionByOption[questionIndex],
),
);
const openNoteEditorForActiveOption = () => {
const questionIndex = getActiveQuestionIndex();
if (questionIndex == null) return;
isNoteEditorOpen = true;
const optionIndex = cursorOptionIndexByQuestion[questionIndex];
noteEditor.setText(getQuestionNote(questionIndex, optionIndex));
requestUiRerender();
};
const advanceToNextTabOrSubmit = () => {
activeTabIndex = Math.min(submitTabIndex, activeTabIndex + 1);
};
noteEditor.onChange = (value) => {
const questionIndex = getActiveQuestionIndex();
if (questionIndex == null) return;
const optionIndex = cursorOptionIndexByQuestion[questionIndex];
noteByQuestionByOption[questionIndex][optionIndex] = value;
requestUiRerender();
};
noteEditor.onSubmit = (value) => {
const questionIndex = getActiveQuestionIndex();
if (questionIndex == null) return;
const preparedQuestion = preparedQuestions[questionIndex];
const optionIndex = cursorOptionIndexByQuestion[questionIndex];
noteByQuestionByOption[questionIndex][optionIndex] = value;
const trimmedNote = value.trim();
if (preparedQuestion.multi) {
if (trimmedNote.length > 0) {
selectedOptionIndexesByQuestion[questionIndex] = addIndexToSelection(
selectedOptionIndexesByQuestion[questionIndex],
optionIndex,
);
}
if (optionIndex === preparedQuestion.otherOptionIndex && trimmedNote.length === 0) {
requestUiRerender();
return;
}
isNoteEditorOpen = false;
requestUiRerender();
return;
}
selectedOptionIndexesByQuestion[questionIndex] = [optionIndex];
if (optionIndex === preparedQuestion.otherOptionIndex && trimmedNote.length === 0) {
requestUiRerender();
return;
}
isNoteEditorOpen = false;
advanceToNextTabOrSubmit();
requestUiRerender();
};
const renderTabs = (): string => {
const tabParts: string[] = ["← "];
for (let questionIndex = 0; questionIndex < preparedQuestions.length; questionIndex++) {
const preparedQuestion = preparedQuestions[questionIndex];
const isActiveTab = questionIndex === activeTabIndex;
const isQuestionValid = isQuestionSelectionValid(
preparedQuestion,
selectedOptionIndexesByQuestion[questionIndex],
noteByQuestionByOption[questionIndex],
);
const statusIcon = isQuestionValid ? "■" : "□";
const tabLabel = ` ${statusIcon} ${preparedQuestion.tabLabel} `;
const styledTabLabel = isActiveTab
? theme.bg("selectedBg", theme.fg("text", tabLabel))
: theme.fg(isQuestionValid ? "success" : "muted", tabLabel);
tabParts.push(`${styledTabLabel} `);
}
const isSubmitTabActive = activeTabIndex === submitTabIndex;
const canSubmit = isAllQuestionSelectionsValid();
const submitLabel = " ✓ Submit ";
const styledSubmitLabel = isSubmitTabActive
? theme.bg("selectedBg", theme.fg("text", submitLabel))
: theme.fg(canSubmit ? "success" : "dim", submitLabel);
tabParts.push(`${styledSubmitLabel}`);
return tabParts.join("");
};
const renderSubmitTab = (width: number, renderedLines: string[]): void => {
const addLine = (line: string) => renderedLines.push(truncateToWidth(line, width));
addLine(theme.fg("accent", theme.bold(" Review answers")));
renderedLines.push("");
for (let questionIndex = 0; questionIndex < preparedQuestions.length; questionIndex++) {
const preparedQuestion = preparedQuestions[questionIndex];
const selection = buildSelectionForQuestion(
preparedQuestion,
selectedOptionIndexesByQuestion[questionIndex],
noteByQuestionByOption[questionIndex],
);
const value = formatSelectionForSubmitReview(selection, preparedQuestion.multi);
const isValid = isQuestionSelectionValid(
preparedQuestion,
selectedOptionIndexesByQuestion[questionIndex],
noteByQuestionByOption[questionIndex],
);
const statusIcon = isValid ? theme.fg("success", "●") : theme.fg("warning", "○");
addLine(` ${statusIcon} ${theme.fg("muted", `${preparedQuestion.tabLabel}:`)} ${theme.fg("text", value)}`);
}
renderedLines.push("");
if (isAllQuestionSelectionsValid()) {
addLine(theme.fg("success", " Press Enter to submit"));
} else {
const missingQuestions = preparedQuestions
.filter((preparedQuestion, questionIndex) =>
!isQuestionSelectionValid(
preparedQuestion,
selectedOptionIndexesByQuestion[questionIndex],
noteByQuestionByOption[questionIndex],
),
)
.map((preparedQuestion) => preparedQuestion.tabLabel)
.join(", ");
addLine(theme.fg("warning", ` Complete required answers: ${missingQuestions}`));
}
addLine(theme.fg("dim", " ←/→ switch tabs • Esc cancel"));
};
const renderQuestionTab = (width: number, renderedLines: string[], questionIndex: number): void => {
const addLine = (line: string) => renderedLines.push(truncateToWidth(line, width));
const preparedQuestion = preparedQuestions[questionIndex];
const cursorOptionIndex = cursorOptionIndexByQuestion[questionIndex];
const selectedOptionIndexes = selectedOptionIndexesByQuestion[questionIndex];
for (const questionLine of wrapTextWithAnsi(preparedQuestion.question, Math.max(1, width - 1))) {
addLine(` ${theme.fg("text", questionLine)}`);
}
renderedLines.push("");
for (let optionIndex = 0; optionIndex < preparedQuestion.options.length; optionIndex++) {
const optionLabel = preparedQuestion.options[optionIndex];
const isCursorOption = optionIndex === cursorOptionIndex;
const isOptionSelected = selectedOptionIndexes.includes(optionIndex);
const isEditingThisOption = isNoteEditorOpen && isCursorOption;
const cursorPrefixText = isCursorOption ? "→ " : " ";
const cursorPrefix = isCursorOption ? theme.fg("accent", cursorPrefixText) : cursorPrefixText;
const markerText = preparedQuestion.multi
? `${isOptionSelected ? "[x]" : "[ ]"} `
: `${isOptionSelected ? "●" : "○"} `;
const optionColor = isCursorOption ? "accent" : isOptionSelected ? "success" : "text";
const prefixWidth = visibleWidth(cursorPrefixText) + visibleWidth(markerText);
const wrappedInlineLabelLines = buildWrappedOptionLabelWithInlineNote(
optionLabel,
getQuestionNote(questionIndex, optionIndex),
isEditingThisOption,
Math.max(1, width - prefixWidth),
INLINE_NOTE_WRAP_PADDING,
);
const continuationPrefix = " ".repeat(prefixWidth);
addLine(`${cursorPrefix}${theme.fg(optionColor, `${markerText}${wrappedInlineLabelLines[0] ?? ""}`)}`);
for (const wrappedLine of wrappedInlineLabelLines.slice(1)) {
addLine(`${continuationPrefix}${theme.fg(optionColor, wrappedLine)}`);
}
}
renderedLines.push("");
if (isNoteEditorOpen) {
addLine(theme.fg("dim", " Typing note inline • Enter save note • Tab/Esc stop editing"));
} else {
if (preparedQuestion.multi) {
addLine(
theme.fg(
"dim",
" ↑↓ move • Enter toggle/select • Tab add note • ←/→ switch tabs • Esc cancel",
),
);
} else {
addLine(
theme.fg("dim", " ↑↓ move • Enter select • Tab add note • ←/→ switch tabs • Esc cancel"),
);
}
}
};
const render = (width: number): string[] => {
if (cachedRenderedLines) return cachedRenderedLines;
const renderedLines: string[] = [];
const addLine = (line: string) => renderedLines.push(truncateToWidth(line, width));
addLine(theme.fg("accent", "─".repeat(width)));
addLine(` ${renderTabs()}`);
renderedLines.push("");
if (activeTabIndex === submitTabIndex) {
renderSubmitTab(width, renderedLines);
} else {
renderQuestionTab(width, renderedLines, activeTabIndex);
}
addLine(theme.fg("accent", "─".repeat(width)));
cachedRenderedLines = renderedLines;
return renderedLines;
};
const handleInput = (data: string) => {
if (isNoteEditorOpen) {
if (matchesKey(data, Key.tab) || matchesKey(data, Key.escape)) {
isNoteEditorOpen = false;
requestUiRerender();
return;
}
noteEditor.handleInput(data);
requestUiRerender();
return;
}
if (matchesKey(data, Key.left)) {
activeTabIndex = (activeTabIndex - 1 + preparedQuestions.length + 1) % (preparedQuestions.length + 1);
requestUiRerender();
return;
}
if (matchesKey(data, Key.right)) {
activeTabIndex = (activeTabIndex + 1) % (preparedQuestions.length + 1);
requestUiRerender();
return;
}
if (activeTabIndex === submitTabIndex) {
if (matchesKey(data, Key.enter) && isAllQuestionSelectionsValid()) {
done(createTabsUiStateSnapshot(false, selectedOptionIndexesByQuestion, noteByQuestionByOption));
return;
}
if (matchesKey(data, Key.escape)) {
done(createTabsUiStateSnapshot(true, selectedOptionIndexesByQuestion, noteByQuestionByOption));
}
return;
}
const questionIndex = activeTabIndex;
const preparedQuestion = preparedQuestions[questionIndex];
if (matchesKey(data, Key.up)) {
cursorOptionIndexByQuestion[questionIndex] = Math.max(0, cursorOptionIndexByQuestion[questionIndex] - 1);
requestUiRerender();
return;
}
if (matchesKey(data, Key.down)) {
cursorOptionIndexByQuestion[questionIndex] = Math.min(
preparedQuestion.options.length - 1,
cursorOptionIndexByQuestion[questionIndex] + 1,
);
requestUiRerender();
return;
}
if (matchesKey(data, Key.tab)) {
openNoteEditorForActiveOption();
return;
}
if (matchesKey(data, Key.enter)) {
const cursorOptionIndex = cursorOptionIndexByQuestion[questionIndex];
if (preparedQuestion.multi) {
const currentlySelected = selectedOptionIndexesByQuestion[questionIndex];
if (currentlySelected.includes(cursorOptionIndex)) {
selectedOptionIndexesByQuestion[questionIndex] = removeIndexFromSelection(currentlySelected, cursorOptionIndex);
} else {
selectedOptionIndexesByQuestion[questionIndex] = addIndexToSelection(currentlySelected, cursorOptionIndex);
}
if (
cursorOptionIndex === preparedQuestion.otherOptionIndex &&
selectedOptionIndexesByQuestion[questionIndex].includes(cursorOptionIndex) &&
getTrimmedQuestionNote(questionIndex, cursorOptionIndex).length === 0
) {
openNoteEditorForActiveOption();
return;
}
requestUiRerender();
return;
}
selectedOptionIndexesByQuestion[questionIndex] = [cursorOptionIndex];
if (
cursorOptionIndex === preparedQuestion.otherOptionIndex &&
getTrimmedQuestionNote(questionIndex, cursorOptionIndex).length === 0
) {
openNoteEditorForActiveOption();
return;
}
advanceToNextTabOrSubmit();
requestUiRerender();
return;
}
if (matchesKey(data, Key.escape)) {
done(createTabsUiStateSnapshot(true, selectedOptionIndexesByQuestion, noteByQuestionByOption));
}
};
return {
render,
invalidate: () => {
cachedRenderedLines = undefined;
},
handleInput,
};
});
if (result.cancelled) {
return {
cancelled: true,
selections: preparedQuestions.map(() => ({ selectedOptions: [] } satisfies AskSelection)),
};
}
const selections = preparedQuestions.map((preparedQuestion, questionIndex) =>
buildSelectionForQuestion(
preparedQuestion,
result.selectedOptionIndexesByQuestion[questionIndex] ?? [],
result.noteByQuestionByOption[questionIndex] ?? Array(preparedQuestion.options.length).fill(""),
),
);
return { cancelled: result.cancelled, selections };
}

View File

@@ -0,0 +1,237 @@
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
import { Type, type Static } from "@sinclair/typebox";
import { OTHER_OPTION, type AskQuestion } from "./ask-logic";
import { askSingleQuestionWithInlineNote } from "./ask-inline-ui";
import { askQuestionsWithTabs } from "./ask-tabs-ui";
const OptionItemSchema = Type.Object({
label: Type.String({ description: "Display label" }),
});
const QuestionItemSchema = Type.Object({
id: Type.String({ description: "Question id (e.g. auth, cache, priority)" }),
question: Type.String({ description: "Question text" }),
options: Type.Array(OptionItemSchema, {
description: "Available options. Do not include 'Other'.",
minItems: 1,
}),
multi: Type.Optional(Type.Boolean({ description: "Allow multi-select" })),
recommended: Type.Optional(
Type.Number({ description: "0-indexed recommended option. '(Recommended)' is shown automatically." }),
),
});
const AskParamsSchema = Type.Object({
questions: Type.Array(QuestionItemSchema, { description: "Questions to ask", minItems: 1 }),
});
type AskParams = Static<typeof AskParamsSchema>;
interface QuestionResult {
id: string;
question: string;
options: string[];
multi: boolean;
selectedOptions: string[];
customInput?: string;
}
interface AskToolDetails {
id?: string;
question?: string;
options?: string[];
multi?: boolean;
selectedOptions?: string[];
customInput?: string;
results?: QuestionResult[];
}
function sanitizeForSessionText(value: string): string {
return value
.replace(/[\r\n\t]/g, " ")
.replace(/[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]/g, "")
.replace(/\s{2,}/g, " ")
.trim();
}
function sanitizeOptionForSessionText(option: string): string {
const sanitizedOption = sanitizeForSessionText(option);
return sanitizedOption.length > 0 ? sanitizedOption : "(empty option)";
}
function toSessionSafeQuestionResult(result: QuestionResult): QuestionResult {
const selectedOptions = result.selectedOptions
.map((selectedOption) => sanitizeForSessionText(selectedOption))
.filter((selectedOption) => selectedOption.length > 0);
const rawCustomInput = result.customInput;
const customInput = rawCustomInput == null ? undefined : sanitizeForSessionText(rawCustomInput);
return {
id: sanitizeForSessionText(result.id) || "(unknown)",
question: sanitizeForSessionText(result.question) || "(empty question)",
options: result.options.map(sanitizeOptionForSessionText),
multi: result.multi,
selectedOptions,
customInput: customInput && customInput.length > 0 ? customInput : undefined,
};
}
function formatSelectionForSummary(result: QuestionResult): string {
const hasSelectedOptions = result.selectedOptions.length > 0;
const hasCustomInput = Boolean(result.customInput);
if (!hasSelectedOptions && !hasCustomInput) {
return "(cancelled)";
}
if (hasSelectedOptions && hasCustomInput) {
const selectedPart = result.multi
? `[${result.selectedOptions.join(", ")}]`
: result.selectedOptions[0];
return `${selectedPart} + Other: "${result.customInput}"`;
}
if (hasCustomInput) {
return `"${result.customInput}"`;
}
if (result.multi) {
return `[${result.selectedOptions.join(", ")}]`;
}
return result.selectedOptions[0];
}
function formatQuestionResult(result: QuestionResult): string {
return `${result.id}: ${formatSelectionForSummary(result)}`;
}
function formatQuestionContext(result: QuestionResult, questionIndex: number): string {
const lines: string[] = [
`Question ${questionIndex + 1} (${result.id})`,
`Prompt: ${result.question}`,
"Options:",
...result.options.map((option, optionIndex) => ` ${optionIndex + 1}. ${option}`),
"Response:",
];
const hasSelectedOptions = result.selectedOptions.length > 0;
const hasCustomInput = Boolean(result.customInput);
if (!hasSelectedOptions && !hasCustomInput) {
lines.push(" Selected: (cancelled)");
return lines.join("\n");
}
if (hasSelectedOptions) {
const selectedText = result.multi
? `[${result.selectedOptions.join(", ")}]`
: result.selectedOptions[0];
lines.push(` Selected: ${selectedText}`);
}
if (hasCustomInput) {
if (!hasSelectedOptions) {
lines.push(` Selected: ${OTHER_OPTION}`);
}
lines.push(` Custom input: ${result.customInput}`);
}
return lines.join("\n");
}
function buildAskSessionContent(results: QuestionResult[]): string {
const safeResults = results.map(toSessionSafeQuestionResult);
const summaryLines = safeResults.map(formatQuestionResult).join("\n");
const contextBlocks = safeResults.map((result, index) => formatQuestionContext(result, index)).join("\n\n");
return `User answers:\n${summaryLines}\n\nAnswer context:\n${contextBlocks}`;
}
const ASK_TOOL_DESCRIPTION = `
Ask the user for clarification when a choice materially affects the outcome.
- Use when multiple valid approaches have different trade-offs.
- Prefer 2-5 concise options.
- Use multi=true when multiple answers are valid.
- Use recommended=<index> (0-indexed) to mark the default option.
- You can ask multiple related questions in one call using questions[].
- Do NOT include an 'Other' option; UI adds it automatically.
`.trim();
export default function askExtension(pi: ExtensionAPI) {
pi.registerTool({
name: "ask",
label: "Ask",
description: ASK_TOOL_DESCRIPTION,
parameters: AskParamsSchema,
async execute(_toolCallId, params: AskParams, _signal, _onUpdate, ctx) {
if (!ctx.hasUI) {
return {
content: [{ type: "text", text: "Error: ask tool requires interactive mode" }],
details: {},
};
}
if (params.questions.length === 0) {
return {
content: [{ type: "text", text: "Error: questions must not be empty" }],
details: {},
};
}
if (params.questions.length === 1) {
const [q] = params.questions;
const selection = q.multi
? (await askQuestionsWithTabs(ctx.ui, [q as AskQuestion])).selections[0] ?? { selectedOptions: [] }
: await askSingleQuestionWithInlineNote(ctx.ui, q as AskQuestion);
const optionLabels = q.options.map((option) => option.label);
const result: QuestionResult = {
id: q.id,
question: q.question,
options: optionLabels,
multi: q.multi ?? false,
selectedOptions: selection.selectedOptions,
customInput: selection.customInput,
};
const details: AskToolDetails = {
id: q.id,
question: q.question,
options: optionLabels,
multi: q.multi ?? false,
selectedOptions: selection.selectedOptions,
customInput: selection.customInput,
results: [result],
};
return {
content: [{ type: "text", text: buildAskSessionContent([result]) }],
details,
};
}
const results: QuestionResult[] = [];
const tabResult = await askQuestionsWithTabs(ctx.ui, params.questions as AskQuestion[]);
for (let i = 0; i < params.questions.length; i++) {
const q = params.questions[i];
const selection = tabResult.selections[i] ?? { selectedOptions: [] };
results.push({
id: q.id,
question: q.question,
options: q.options.map((option) => option.label),
multi: q.multi ?? false,
selectedOptions: selection.selectedOptions,
customInput: selection.customInput,
});
}
return {
content: [{ type: "text", text: buildAskSessionContent(results) }],
details: { results } satisfies AskToolDetails,
};
},
});
}

View File

@@ -1,9 +1,12 @@
/**
* Usage Extension - Minimal API usage indicator for pi
*
* Shows Codex (OpenAI), Anthropic (Claude), Z.AI, and optionally
* Google Gemini CLI / Antigravity usage as color-coded percentages
* in the footer status bar.
* Polls Codex, Anthropic, Z.AI, Gemini CLI / Antigravity usage and exposes it
* via two channels:
* • pi.events "usage:update" — for other extensions (e.g. footer-display)
* • ctx.ui.setStatus("usage-bars", …) — formatted S/W braille bars
*
* Rendering / footer layout is handled by the separate footer-display extension.
*/
import { DynamicBorder, type ExtensionAPI } from "@mariozechner/pi-coding-agent";
@@ -37,20 +40,39 @@ import {
type UsageData,
} from "./core";
// Disk cache TTL for idle/background reads (session start, etc.)
const CACHE_TTL_MS = 15 * 60 * 1000;
// Shorter TTL for event-driven polls (after prompt submit / after turn end).
// With the shared disk cache, only one pi instance per ACTIVE_CACHE_TTL_MS window
// will actually hit the API — the rest will read from the cached result.
const ACTIVE_CACHE_TTL_MS = 3 * 60 * 1000;
// How often to re-poll while the model is actively streaming / running tools.
// Combined with the shared disk cache this means at most one HTTP request per
// STREAMING_POLL_INTERVAL_MS regardless of how many pi sessions are open.
const STREAMING_POLL_INTERVAL_MS = 2 * 60 * 1000;
const RATE_LIMITED_BACKOFF_MS = 60 * 60 * 1000;
const RATE_LIMITED_BACKOFF_MS = 60 * 60 * 1000; // 1 hour back-off after 429
const STATUS_KEY = "usage-bars";
// ---------------------------------------------------------------------------
// Braille gradient bar (⣀ ⣄ ⣤ ⣦ ⣶ ⣷ ⣿)
// ---------------------------------------------------------------------------
const BRAILLE_GRADIENT = "\u28C0\u28C4\u28E4\u28E6\u28F6\u28F7\u28FF";
const BRAILLE_EMPTY = "\u28C0";
const BAR_WIDTH = 5;
function renderBrailleBar(theme: any, value: number, width = BAR_WIDTH): string {
const v = clampPercent(value);
const levels = BRAILLE_GRADIENT.length - 1;
const totalSteps = width * levels;
const filledSteps = Math.round((v / 100) * totalSteps);
const full = Math.floor(filledSteps / levels);
const partial = filledSteps % levels;
const empty = width - full - (partial ? 1 : 0);
const color = colorForPercent(v);
const filled = BRAILLE_GRADIENT[BRAILLE_GRADIENT.length - 1]!.repeat(Math.max(0, full));
const partialChar = partial ? BRAILLE_GRADIENT[partial]! : "";
const emptyChars = BRAILLE_EMPTY.repeat(Math.max(0, empty));
return theme.fg(color, filled + partialChar) + theme.fg("dim", emptyChars);
}
function renderBrailleBarWide(theme: any, value: number): string {
return renderBrailleBar(theme, value, 12);
}
const PROVIDER_LABELS: Record<ProviderKey, string> = {
codex: "Codex",
claude: "Claude",
@@ -59,6 +81,9 @@ const PROVIDER_LABELS: Record<ProviderKey, string> = {
antigravity: "Antigravity",
};
// ---------------------------------------------------------------------------
// /usage command popup
// ---------------------------------------------------------------------------
interface SubscriptionItem {
name: string;
provider: ProviderKey;
@@ -81,14 +106,8 @@ class UsageSelectorComponent extends Container implements Focusable {
private fetchAllFn: () => Promise<UsageByProvider>;
private _focused = false;
get focused(): boolean {
return this._focused;
}
set focused(value: boolean) {
this._focused = value;
this.searchInput.focused = value;
}
get focused(): boolean { return this._focused; }
set focused(value: boolean) { this._focused = value; this.searchInput.focused = value; }
constructor(
tui: any,
@@ -106,19 +125,15 @@ class UsageSelectorComponent extends Container implements Focusable {
this.addChild(new DynamicBorder((s: string) => theme.fg("accent", s)));
this.addChild(new Spacer(1));
this.hintText = new Text(theme.fg("dim", "Fetching usage from all providers…"), 0, 0);
this.addChild(this.hintText);
this.addChild(new Spacer(1));
this.searchInput = new Input();
this.addChild(this.searchInput);
this.addChild(new Spacer(1));
this.listContainer = new Container();
this.addChild(this.listContainer);
this.addChild(new Spacer(1));
this.addChild(new DynamicBorder((s: string) => theme.fg("accent", s)));
this.fetchAllFn()
@@ -149,7 +164,6 @@ class UsageSelectorComponent extends Container implements Focusable {
{ key: "gemini", name: "Gemini" },
{ key: "antigravity", name: "Antigravity" },
];
this.allItems = [];
for (const p of providers) {
if (results[p.key] !== null) {
@@ -161,7 +175,6 @@ class UsageSelectorComponent extends Container implements Focusable {
});
}
}
this.filteredItems = this.allItems;
this.selectedIndex = Math.min(this.selectedIndex, Math.max(0, this.filteredItems.length - 1));
}
@@ -178,23 +191,12 @@ class UsageSelectorComponent extends Container implements Focusable {
this.selectedIndex = Math.min(this.selectedIndex, Math.max(0, this.filteredItems.length - 1));
}
private renderBar(pct: number, width = 16): string {
const value = clampPercent(pct);
const filled = Math.round((value / 100) * width);
const color = colorForPercent(value);
const full = "█".repeat(Math.max(0, filled));
const empty = "░".repeat(Math.max(0, width - filled));
return this.theme.fg(color, full) + this.theme.fg("dim", empty);
}
private renderItem(item: SubscriptionItem, isSelected: boolean) {
const t = this.theme;
const pointer = isSelected ? t.fg("accent", "→ ") : " ";
const activeBadge = item.isActive ? t.fg("success", " ✓") : "";
const name = isSelected ? t.fg("accent", t.bold(item.name)) : item.name;
this.listContainer.addChild(new Text(`${pointer}${name}${activeBadge}`, 0, 0));
const indent = " ";
if (!item.data) {
@@ -204,69 +206,45 @@ class UsageSelectorComponent extends Container implements Focusable {
} else {
const session = clampPercent(item.data.session);
const weekly = clampPercent(item.data.weekly);
const sessionReset = item.data.sessionResetsIn
? t.fg("dim", ` resets in ${item.data.sessionResetsIn}`)
: "";
? t.fg("dim", ` resets in ${item.data.sessionResetsIn}`) : "";
const weeklyReset = item.data.weeklyResetsIn
? t.fg("dim", ` resets in ${item.data.weeklyResetsIn}`)
: "";
? t.fg("dim", ` resets in ${item.data.weeklyResetsIn}`) : "";
this.listContainer.addChild(
new Text(
indent +
t.fg("muted", "Session ") +
this.renderBar(session) +
" " +
t.fg(colorForPercent(session), `${session}%`.padStart(4)) +
sessionReset,
0,
0,
),
);
this.listContainer.addChild(
new Text(
indent +
t.fg("muted", "Weekly ") +
this.renderBar(weekly) +
" " +
t.fg(colorForPercent(weekly), `${weekly}%`.padStart(4)) +
weeklyReset,
0,
0,
),
);
this.listContainer.addChild(new Text(
indent + t.fg("muted", "Session ") +
renderBrailleBarWide(t, session) + " " +
t.fg(colorForPercent(session), `${session}%`.padStart(4)) + sessionReset,
0, 0,
));
this.listContainer.addChild(new Text(
indent + t.fg("muted", "Weekly ") +
renderBrailleBarWide(t, weekly) + " " +
t.fg(colorForPercent(weekly), `${weekly}%`.padStart(4)) + weeklyReset,
0, 0,
));
if (typeof item.data.extraSpend === "number" && typeof item.data.extraLimit === "number") {
this.listContainer.addChild(
new Text(
indent +
t.fg("muted", "Extra ") +
t.fg("dim", `$${item.data.extraSpend.toFixed(2)} / $${item.data.extraLimit}`),
0,
0,
),
);
this.listContainer.addChild(new Text(
indent + t.fg("muted", "Extra ") +
t.fg("dim", `$${item.data.extraSpend.toFixed(2)} / $${item.data.extraLimit}`),
0, 0,
));
}
}
this.listContainer.addChild(new Spacer(1));
}
private updateList() {
this.listContainer.clear();
if (this.loading) {
this.listContainer.addChild(new Text(this.theme.fg("muted", " Loading…"), 0, 0));
return;
}
if (this.filteredItems.length === 0) {
this.listContainer.addChild(new Text(this.theme.fg("muted", " No matching providers"), 0, 0));
return;
}
for (let i = 0; i < this.filteredItems.length; i++) {
this.renderItem(this.filteredItems[i]!, i === this.selectedIndex);
}
@@ -274,91 +252,58 @@ class UsageSelectorComponent extends Container implements Focusable {
handleInput(keyData: string): void {
const kb = getEditorKeybindings();
if (kb.matches(keyData, "selectUp")) {
if (this.filteredItems.length === 0) return;
this.selectedIndex =
this.selectedIndex === 0 ? this.filteredItems.length - 1 : this.selectedIndex - 1;
this.updateList();
return;
this.selectedIndex = this.selectedIndex === 0 ? this.filteredItems.length - 1 : this.selectedIndex - 1;
this.updateList(); return;
}
if (kb.matches(keyData, "selectDown")) {
if (this.filteredItems.length === 0) return;
this.selectedIndex =
this.selectedIndex === this.filteredItems.length - 1 ? 0 : this.selectedIndex + 1;
this.updateList();
return;
this.selectedIndex = this.selectedIndex === this.filteredItems.length - 1 ? 0 : this.selectedIndex + 1;
this.updateList(); return;
}
if (kb.matches(keyData, "selectCancel") || kb.matches(keyData, "selectConfirm")) {
this.onCancelCallback();
return;
this.onCancelCallback(); return;
}
this.searchInput.handleInput(keyData);
this.filterItems(this.searchInput.getValue());
this.updateList();
}
}
// ---------------------------------------------------------------------------
// Extension state
// ---------------------------------------------------------------------------
interface UsageState extends UsageByProvider {
lastPoll: number;
activeProvider: ProviderKey | null;
}
interface PollOptions {
/** Override the disk-cache TTL for this poll (default: CACHE_TTL_MS). */
cacheTtl?: number;
/**
* Skip the shared disk-cache TTL check entirely and always fetch from the
* API. Used after an account switch where the cached data belongs to a
* different account.
*/
forceFresh?: boolean;
}
export default function (pi: ExtensionAPI) {
const endpoints = resolveUsageEndpoints();
const state: UsageState = {
codex: null,
claude: null,
zai: null,
gemini: null,
antigravity: null,
lastPoll: 0,
activeProvider: null,
codex: null, claude: null, zai: null, gemini: null, antigravity: null,
lastPoll: 0, activeProvider: null,
};
let pollInFlight: Promise<void> | null = null;
let pollQueued = false;
/** Timer running during an active agent loop to refresh usage periodically. */
let streamingTimer: ReturnType<typeof setInterval> | null = null;
let ctx: any = null;
function renderPercent(theme: any, value: number): string {
const v = clampPercent(value);
return theme.fg(colorForPercent(v), `${v}%`);
}
function renderBar(theme: any, value: number): string {
const v = clampPercent(value);
const width = 8;
const filled = Math.round((v / 100) * width);
const full = "█".repeat(Math.max(0, Math.min(width, filled)));
const empty = "░".repeat(Math.max(0, width - filled));
return theme.fg(colorForPercent(v), full) + theme.fg("dim", empty);
}
function pickDataForProvider(provider: ProviderKey | null): UsageData | null {
if (!provider) return null;
return state[provider];
}
// ---------------------------------------------------------------------------
// Status update
// ---------------------------------------------------------------------------
function updateStatus() {
const active = state.activeProvider;
const data = pickDataForProvider(active);
const data = active ? state[active] : null;
// Always emit event for other extensions (e.g. footer-display)
if (data && !data.error) {
pi.events.emit("usage:update", {
session: data.session,
@@ -370,6 +315,8 @@ export default function (pi: ExtensionAPI) {
if (!ctx?.hasUI) return;
const theme = ctx.ui.theme;
if (!active) {
ctx.ui.setStatus(STATUS_KEY, undefined);
return;
@@ -381,128 +328,92 @@ export default function (pi: ExtensionAPI) {
return;
}
const theme = ctx.ui.theme;
const label = PROVIDER_LABELS[active];
if (!data) {
ctx.ui.setStatus(STATUS_KEY, theme.fg("dim", `${label} usage: loading…`));
ctx.ui.setStatus(STATUS_KEY, theme.fg("dim", "loading\u2026"));
return;
}
if (data.error) {
const cache = readUsageCache();
const blockedUntil = active ? (cache?.rateLimitedUntil?.[active] ?? 0) : 0;
const backoffNote = blockedUntil > Date.now()
? ` retry in ${Math.ceil((blockedUntil - Date.now()) / 60000)}m`
: "";
ctx.ui.setStatus(STATUS_KEY, theme.fg("warning", `${label} usage unavailable (${data.error}${backoffNote})`));
const blockedUntil = cache?.rateLimitedUntil?.[active] ?? 0;
const note = blockedUntil > Date.now()
? ` \u2014 retry in ${Math.ceil((blockedUntil - Date.now()) / 60000)}m` : "";
ctx.ui.setStatus(STATUS_KEY, theme.fg("warning", `${PROVIDER_LABELS[active]} unavailable${note}`));
return;
}
const session = clampPercent(data.session);
const weekly = clampPercent(data.weekly);
const sessionReset = data.sessionResetsIn ? theme.fg("dim", `${data.sessionResetsIn}`) : "";
const weeklyReset = data.weeklyResetsIn ? theme.fg("dim", `${data.weeklyResetsIn}`) : "";
let s = theme.fg("muted", "S ") + renderBrailleBar(theme, session) + " " + theme.fg("dim", `${session}%`);
if (data.sessionResetsIn) s += " " + theme.fg("dim", data.sessionResetsIn);
const status =
theme.fg("dim", `${label} `) +
theme.fg("muted", "S ") +
renderBar(theme, session) +
" " +
renderPercent(theme, session) +
sessionReset +
theme.fg("muted", " W ") +
renderBar(theme, weekly) +
" " +
renderPercent(theme, weekly) +
weeklyReset;
let w = theme.fg("muted", "W ") + renderBrailleBar(theme, weekly) + " " + theme.fg("dim", `${weekly}%`);
if (data.weeklyResetsIn) w += " " + theme.fg("dim", `\u27F3 ${data.weeklyResetsIn}`);
ctx.ui.setStatus(STATUS_KEY, status);
ctx.ui.setStatus(STATUS_KEY, s + theme.fg("dim", " | ") + w);
}
function updateProviderFrom(modelLike: any): boolean {
const previous = state.activeProvider;
state.activeProvider = detectProvider(modelLike);
if (previous !== state.activeProvider) {
updateStatus();
return true;
}
if (previous !== state.activeProvider) { updateStatus(); return true; }
return false;
}
// ---------------------------------------------------------------------------
// Polling
// ---------------------------------------------------------------------------
async function runPoll(options: PollOptions = {}) {
const auth = readAuth();
const active = state.activeProvider;
if (!canShowForProvider(active, auth, endpoints) || !auth || !active) {
state.lastPoll = Date.now();
updateStatus();
return;
state.lastPoll = Date.now(); updateStatus(); return;
}
const cache = readUsageCache();
const now = Date.now();
const cacheTtl = options.cacheTtl ?? CACHE_TTL_MS;
// Respect cross-session rate-limit back-off written by any session.
const blockedUntil = cache?.rateLimitedUntil?.[active] ?? 0;
if (now < blockedUntil) {
if (cache?.data?.[active]) state[active] = cache.data[active]!;
state.lastPoll = now;
updateStatus();
return;
state.lastPoll = now; updateStatus(); return;
}
// Use shared disk cache unless forceFresh is set (e.g. after account switch).
if (!options.forceFresh && cache && now - cache.timestamp < cacheTtl && cache.data?.[active]) {
state[active] = cache.data[active]!;
state.lastPoll = now;
updateStatus();
return;
state.lastPoll = now; updateStatus(); return;
}
// --- Proactive token refresh ---
const oauthId = providerToOAuthProviderId(active);
let effectiveAuth = auth;
if (oauthId && active !== "zai") {
const creds = auth[oauthId as keyof typeof auth] as
| { access?: string; refresh?: string; expires?: number }
| undefined;
| { access?: string; refresh?: string; expires?: number } | undefined;
const expires = typeof creds?.expires === "number" ? creds.expires : 0;
const tokenExpiredOrMissing =
!creds?.access || (expires > 0 && Date.now() + 60_000 >= expires);
const tokenExpiredOrMissing = !creds?.access || (expires > 0 && Date.now() + 60_000 >= expires);
if (tokenExpiredOrMissing && creds?.refresh) {
try {
const refreshed = await ensureFreshAuthForProviders([oauthId as OAuthProviderId], {
auth,
persist: true,
});
const refreshed = await ensureFreshAuthForProviders([oauthId as OAuthProviderId], { auth, persist: true });
if (refreshed.auth) effectiveAuth = refreshed.auth;
} catch {
// Ignore refresh errors — fall through with existing auth
}
} catch {}
}
}
let result: UsageData;
if (active === "codex") {
const access = effectiveAuth["openai-codex"]?.access;
result = access
? await fetchCodexUsage(access)
result = access ? await fetchCodexUsage(access)
: { session: 0, weekly: 0, error: "missing access token (try /login again)" };
} else if (active === "claude") {
const access = effectiveAuth.anthropic?.access;
result = access
? await fetchClaudeUsage(access)
result = access ? await fetchClaudeUsage(access)
: { session: 0, weekly: 0, error: "missing access token (try /login again)" };
} else if (active === "zai") {
const token = effectiveAuth.zai?.access || effectiveAuth.zai?.key;
result = token
? await fetchZaiUsage(token, { endpoints })
result = token ? await fetchZaiUsage(token, { endpoints })
: { session: 0, weekly: 0, error: "missing token (try /login again)" };
} else if (active === "gemini") {
const creds = effectiveAuth["google-gemini-cli"];
@@ -523,14 +434,10 @@ export default function (pi: ExtensionAPI) {
const nextCache: import("./core").UsageCache = {
timestamp: cache?.timestamp ?? now,
data: { ...(cache?.data ?? {}) },
rateLimitedUntil: {
...(cache?.rateLimitedUntil ?? {}),
[active]: now + RATE_LIMITED_BACKOFF_MS,
},
rateLimitedUntil: { ...(cache?.rateLimitedUntil ?? {}), [active]: now + RATE_LIMITED_BACKOFF_MS },
};
writeUsageCache(nextCache);
}
// All other errors: don't update cache — next turn will retry from scratch.
} else {
const nextCache: import("./core").UsageCache = {
timestamp: now,
@@ -546,41 +453,24 @@ export default function (pi: ExtensionAPI) {
}
async function poll(options: PollOptions = {}) {
if (pollInFlight) {
pollQueued = true;
await pollInFlight;
return;
}
if (pollInFlight) { pollQueued = true; await pollInFlight; return; }
do {
pollQueued = false;
pollInFlight = runPoll(options)
.catch(() => {
// Never crash extension event handlers on transient polling errors.
})
.finally(() => {
pollInFlight = null;
});
pollInFlight = runPoll(options).catch(() => {}).finally(() => { pollInFlight = null; });
await pollInFlight;
} while (pollQueued);
}
function startStreamingTimer() {
if (streamingTimer !== null) return; // already running
streamingTimer = setInterval(() => {
void poll({ cacheTtl: ACTIVE_CACHE_TTL_MS });
}, STREAMING_POLL_INTERVAL_MS);
if (streamingTimer !== null) return;
streamingTimer = setInterval(() => { void poll({ cacheTtl: ACTIVE_CACHE_TTL_MS }); }, STREAMING_POLL_INTERVAL_MS);
}
function stopStreamingTimer() {
if (streamingTimer !== null) {
clearInterval(streamingTimer);
streamingTimer = null;
}
if (streamingTimer !== null) { clearInterval(streamingTimer); streamingTimer = null; }
}
// ── Session lifecycle ────────────────────────────────────────────────────
// ── Lifecycle ────────────────────────────────────────────────────────────
pi.on("session_start", async (_event, _ctx) => {
ctx = _ctx;
@@ -590,63 +480,34 @@ export default function (pi: ExtensionAPI) {
pi.on("session_shutdown", async (_event, _ctx) => {
stopStreamingTimer();
if (_ctx?.hasUI) {
_ctx.ui.setStatus(STATUS_KEY, undefined);
}
if (_ctx?.hasUI) _ctx.ui.setStatus(STATUS_KEY, undefined);
});
// ── Model change ─────────────────────────────────────────────────────────
pi.on("model_select", async (event, _ctx) => {
ctx = _ctx;
const changed = updateProviderFrom(event.model ?? _ctx.model);
if (changed) await poll();
});
// Keep provider detection up-to-date across turns (cheap, no API call unless
// the provider changed).
pi.on("turn_start", (_event, _ctx) => {
ctx = _ctx;
updateProviderFrom(_ctx.model);
});
pi.on("turn_start", (_event, _ctx) => { ctx = _ctx; updateProviderFrom(_ctx.model); });
// ── Agent loop ───────────────────────────────────────────────────────────
// Poll when the user submits a prompt — captures usage right before the
// new turn (useful to see current state before tokens are consumed).
pi.on("before_agent_start", async (_event, _ctx) => {
ctx = _ctx;
await poll({ cacheTtl: ACTIVE_CACHE_TTL_MS });
});
// Start interval timer so usage stays fresh during long-running agents
// (lots of tool calls, extended thinking, etc.).
pi.on("agent_start", (_event, _ctx) => {
ctx = _ctx;
startStreamingTimer();
});
pi.on("agent_start", (_event, _ctx) => { ctx = _ctx; startStreamingTimer(); });
// When the agent finishes, stop the timer and do a final fresh poll to
// capture the usage that was just consumed.
pi.on("agent_end", async (_event, _ctx) => {
ctx = _ctx;
stopStreamingTimer();
await poll({ cacheTtl: ACTIVE_CACHE_TTL_MS });
});
// ── Account switch ───────────────────────────────────────────────────────
// Re-poll immediately when the Claude account is switched via /switch-claude.
// We must invalidate the claude entry in the shared disk cache first —
// otherwise poll() would serve the previous account's data which is still
// within CACHE_TTL_MS.
pi.events.on("claude-account:switched", () => {
const cache = readUsageCache();
if (cache?.data?.claude) {
const nextCache: import("./core").UsageCache = {
...cache,
data: { ...cache.data },
};
const nextCache: import("./core").UsageCache = { ...cache, data: { ...cache.data } };
delete nextCache.data.claude;
writeUsageCache(nextCache);
}
@@ -660,18 +521,14 @@ export default function (pi: ExtensionAPI) {
handler: async (_args, _ctx) => {
ctx = _ctx;
updateProviderFrom(_ctx.model);
try {
if (_ctx?.hasUI) {
await _ctx.ui.custom<void>((tui, theme, _keybindings, done) => {
const selector = new UsageSelectorComponent(
tui,
theme,
state.activeProvider,
return new UsageSelectorComponent(
tui, theme, state.activeProvider,
() => fetchAllUsages({ endpoints }),
() => done(),
);
return selector;
});
}
} finally {

View File

@@ -0,0 +1,444 @@
/**
* WezTerm Theme Sync Extension
*
* Syncs pi theme with WezTerm terminal colors on startup.
*
* How it works:
* 1. Finds the WezTerm config directory (via $WEZTERM_CONFIG_DIR or defaults)
* 2. Runs the config through luajit to extract effective colors
* 3. Maps ANSI palette slots to pi theme colors
* 4. Writes a pi theme file and activates it
*
* Supports:
* - Inline `config.colors = { ... }` definitions
* - Lua theme modules loaded via require()
* - Any config structure as long as `config.colors` is set
*
* ANSI slots (consistent across themes):
* 0: black 8: bright black (gray/muted)
* 1: red 9: bright red
* 2: green 10: bright green
* 3: yellow 11: bright yellow
* 4: blue 12: bright blue
* 5: magenta 13: bright magenta
* 6: cyan 14: bright cyan
* 7: white 15: bright white
*
* Requirements:
* - WezTerm installed and running (sets $WEZTERM_CONFIG_DIR)
* - luajit or lua available in PATH
*/
import { execSync } from "node:child_process";
import { createHash } from "node:crypto";
import { existsSync, mkdirSync, readdirSync, unlinkSync, writeFileSync } from "node:fs";
import { join } from "node:path";
import { homedir } from "node:os";
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
interface WeztermColors {
background: string;
foreground: string;
palette: Record<number, string>;
}
/**
* Find the WezTerm config directory.
* Checks $WEZTERM_CONFIG_DIR, then standard locations.
*/
function findConfigDir(): string | null {
if (process.env.WEZTERM_CONFIG_DIR && existsSync(process.env.WEZTERM_CONFIG_DIR)) {
return process.env.WEZTERM_CONFIG_DIR;
}
const candidates = [
join(homedir(), ".config", "wezterm"),
join(homedir(), ".wezterm"),
];
for (const dir of candidates) {
if (existsSync(dir)) return dir;
}
return null;
}
/**
* Find which Lua interpreter is available.
*/
function findLua(): string | null {
for (const cmd of ["luajit", "lua5.4", "lua5.3", "lua"]) {
try {
execSync(`which ${cmd}`, { stdio: "pipe" });
return cmd;
} catch {
// Try next
}
}
return null;
}
/**
* Extract colors from WezTerm config by evaluating it with a mocked wezterm module.
* Writes a temporary Lua helper script, runs it with luajit, then cleans up.
*/
function getWeztermColors(configDir: string, lua: string): WeztermColors | null {
const configFile = join(configDir, "wezterm.lua");
if (!existsSync(configFile)) return null;
const tmpScript = join(configDir, ".pi-extract-colors.lua");
const extractScript = `
-- Mock wezterm module with commonly used functions
local mock_wezterm = {
font = function(name) return name end,
font_with_fallback = function(names) return names end,
hostname = function() return "mock" end,
home_dir = ${JSON.stringify(homedir())},
config_dir = ${JSON.stringify(configDir)},
target_triple = "x86_64-unknown-linux-gnu",
version = "mock",
log_info = function() end,
log_warn = function() end,
log_error = function() end,
on = function() end,
action = setmetatable({}, {
__index = function(_, k)
return function(...) return { action = k, args = {...} } end
end
}),
action_callback = function(fn) return fn end,
color = {
parse = function(c) return c end,
get_builtin_schemes = function() return {} end,
},
gui = {
get_appearance = function() return "Dark" end,
},
GLOBAL = {},
nerdfonts = setmetatable({}, { __index = function() return "" end }),
}
mock_wezterm.plugin = { require = function() return {} end }
package.loaded["wezterm"] = mock_wezterm
-- Add config dir to Lua search path
package.path = ${JSON.stringify(configDir)} .. "/?.lua;" ..
${JSON.stringify(configDir)} .. "/?/init.lua;" ..
package.path
-- Try to load the config
local ok, config = pcall(dofile, ${JSON.stringify(configFile)})
if not ok then
io.stderr:write("Failed to load config: " .. tostring(config) .. "\\n")
os.exit(1)
end
if type(config) ~= "table" then
io.stderr:write("Config did not return a table\\n")
os.exit(1)
end
local colors = config.colors
if not colors then
if config.color_scheme then
io.stderr:write("color_scheme=" .. tostring(config.color_scheme) .. "\\n")
end
io.stderr:write("No inline colors found in config\\n")
os.exit(1)
end
if type(colors) == "table" then
if colors.background then print("background=" .. colors.background) end
if colors.foreground then print("foreground=" .. colors.foreground) end
if colors.ansi then
for i, c in ipairs(colors.ansi) do
print("ansi" .. (i-1) .. "=" .. c)
end
end
if colors.brights then
for i, c in ipairs(colors.brights) do
print("bright" .. (i-1) .. "=" .. c)
end
end
end
`;
try {
writeFileSync(tmpScript, extractScript);
const output = execSync(`${lua} ${JSON.stringify(tmpScript)}`, {
encoding: "utf-8",
timeout: 5000,
cwd: configDir,
stdio: ["pipe", "pipe", "pipe"],
});
return parseWeztermOutput(output);
} catch (err: any) {
if (err.stderr) {
console.error(`[wezterm-theme-sync] ${err.stderr.trim()}`);
}
return null;
} finally {
try { unlinkSync(tmpScript); } catch { /* ignore */ }
}
}
function parseWeztermOutput(output: string): WeztermColors {
const colors: WeztermColors = {
background: "#1e1e1e",
foreground: "#d4d4d4",
palette: {},
};
for (const line of output.split("\n")) {
const match = line.match(/^(\w+)=(.+)$/);
if (!match) continue;
const [, key, value] = match;
const color = normalizeColor(value.trim());
if (key === "background") {
colors.background = color;
} else if (key === "foreground") {
colors.foreground = color;
} else {
const ansiMatch = key.match(/^ansi(\d+)$/);
const brightMatch = key.match(/^bright(\d+)$/);
if (ansiMatch) {
const idx = parseInt(ansiMatch[1], 10);
if (idx >= 0 && idx <= 7) colors.palette[idx] = color;
} else if (brightMatch) {
const idx = parseInt(brightMatch[1], 10);
if (idx >= 0 && idx <= 7) colors.palette[idx + 8] = color;
}
}
}
return colors;
}
function normalizeColor(color: string): string {
const trimmed = color.trim();
if (trimmed.startsWith("#")) {
if (trimmed.length === 4) {
return `#${trimmed[1]}${trimmed[1]}${trimmed[2]}${trimmed[2]}${trimmed[3]}${trimmed[3]}`;
}
return trimmed.toLowerCase();
}
if (/^[0-9a-fA-F]{6}$/.test(trimmed)) {
return `#${trimmed}`.toLowerCase();
}
return `#${trimmed}`.toLowerCase();
}
function hexToRgb(hex: string): { r: number; g: number; b: number } {
const h = hex.replace("#", "");
return {
r: parseInt(h.substring(0, 2), 16),
g: parseInt(h.substring(2, 4), 16),
b: parseInt(h.substring(4, 6), 16),
};
}
function rgbToHex(r: number, g: number, b: number): string {
const clamp = (n: number) => Math.round(Math.min(255, Math.max(0, n)));
return `#${clamp(r).toString(16).padStart(2, "0")}${clamp(g).toString(16).padStart(2, "0")}${clamp(b).toString(16).padStart(2, "0")}`;
}
function getLuminance(hex: string): number {
const { r, g, b } = hexToRgb(hex);
return (0.299 * r + 0.587 * g + 0.114 * b) / 255;
}
function adjustBrightness(hex: string, amount: number): string {
const { r, g, b } = hexToRgb(hex);
return rgbToHex(r + amount, g + amount, b + amount);
}
function mixColors(color1: string, color2: string, weight: number): string {
const c1 = hexToRgb(color1);
const c2 = hexToRgb(color2);
return rgbToHex(
c1.r * weight + c2.r * (1 - weight),
c1.g * weight + c2.g * (1 - weight),
c1.b * weight + c2.b * (1 - weight),
);
}
function generatePiTheme(colors: WeztermColors, themeName: string): object {
const bg = colors.background;
const fg = colors.foreground;
const isDark = getLuminance(bg) < 0.5;
// ANSI color slots - trust the standard for semantic colors
const error = colors.palette[1] || "#cc6666";
const success = colors.palette[2] || "#98c379";
const warning = colors.palette[3] || "#e5c07b";
const link = colors.palette[4] || "#61afef";
const accent = colors.palette[5] || "#c678dd";
const accentAlt = colors.palette[6] || "#56b6c2";
// Derive neutrals from bg/fg for consistent readability
const muted = mixColors(fg, bg, 0.65);
const dim = mixColors(fg, bg, 0.45);
const borderMuted = mixColors(fg, bg, 0.25);
// Derive backgrounds
const bgShift = isDark ? 12 : -12;
const selectedBg = adjustBrightness(bg, bgShift);
const userMsgBg = adjustBrightness(bg, Math.round(bgShift * 0.7));
const toolPendingBg = adjustBrightness(bg, Math.round(bgShift * 0.4));
const toolSuccessBg = mixColors(bg, success, 0.88);
const toolErrorBg = mixColors(bg, error, 0.88);
const customMsgBg = mixColors(bg, accent, 0.92);
return {
$schema:
"https://raw.githubusercontent.com/badlogic/pi-mono/main/packages/coding-agent/src/modes/interactive/theme/theme-schema.json",
name: themeName,
vars: {
bg,
fg,
accent,
accentAlt,
link,
error,
success,
warning,
muted,
dim,
borderMuted,
selectedBg,
userMsgBg,
toolPendingBg,
toolSuccessBg,
toolErrorBg,
customMsgBg,
},
colors: {
accent: "accent",
border: "link",
borderAccent: "accent",
borderMuted: "borderMuted",
success: "success",
error: "error",
warning: "warning",
muted: "muted",
dim: "dim",
text: "",
thinkingText: "muted",
selectedBg: "selectedBg",
userMessageBg: "userMsgBg",
userMessageText: "",
customMessageBg: "customMsgBg",
customMessageText: "",
customMessageLabel: "accent",
toolPendingBg: "toolPendingBg",
toolSuccessBg: "toolSuccessBg",
toolErrorBg: "toolErrorBg",
toolTitle: "",
toolOutput: "muted",
mdHeading: "warning",
mdLink: "link",
mdLinkUrl: "dim",
mdCode: "accent",
mdCodeBlock: "success",
mdCodeBlockBorder: "muted",
mdQuote: "muted",
mdQuoteBorder: "muted",
mdHr: "muted",
mdListBullet: "accent",
toolDiffAdded: "success",
toolDiffRemoved: "error",
toolDiffContext: "muted",
syntaxComment: "muted",
syntaxKeyword: "accent",
syntaxFunction: "link",
syntaxVariable: "accentAlt",
syntaxString: "success",
syntaxNumber: "accent",
syntaxType: "accentAlt",
syntaxOperator: "fg",
syntaxPunctuation: "muted",
thinkingOff: "borderMuted",
thinkingMinimal: "muted",
thinkingLow: "link",
thinkingMedium: "accentAlt",
thinkingHigh: "accent",
thinkingXhigh: "accent",
bashMode: "success",
},
export: {
pageBg: isDark ? adjustBrightness(bg, -8) : adjustBrightness(bg, 8),
cardBg: bg,
infoBg: mixColors(bg, warning, 0.88),
},
};
}
function computeThemeHash(colors: WeztermColors): string {
const parts: string[] = [];
parts.push(`bg=${colors.background}`);
parts.push(`fg=${colors.foreground}`);
for (let i = 0; i <= 15; i++) {
parts.push(`p${i}=${colors.palette[i] ?? ""}`);
}
return createHash("sha1").update(parts.join("\n")).digest("hex").slice(0, 8);
}
function cleanupOldThemes(themesDir: string, keepFile: string): void {
try {
for (const file of readdirSync(themesDir)) {
if (file === keepFile) continue;
if (file.startsWith("wezterm-sync-") && file.endsWith(".json")) {
unlinkSync(join(themesDir, file));
}
}
} catch {
// Best-effort cleanup
}
}
export default function (pi: ExtensionAPI) {
pi.on("session_start", async (_event, ctx) => {
const configDir = findConfigDir();
if (!configDir) {
return;
}
const lua = findLua();
if (!lua) {
return;
}
const colors = getWeztermColors(configDir, lua);
if (!colors) {
return;
}
const themesDir = join(homedir(), ".pi", "agent", "themes");
if (!existsSync(themesDir)) {
mkdirSync(themesDir, { recursive: true });
}
const hash = computeThemeHash(colors);
const themeName = `wezterm-sync-${hash}`;
const themeFile = `${themeName}.json`;
const themePath = join(themesDir, themeFile);
// Skip if already on the correct synced theme (avoids repaint)
if (ctx.ui.theme.name === themeName) {
return;
}
const themeJson = generatePiTheme(colors, themeName);
writeFileSync(themePath, JSON.stringify(themeJson, null, 2));
// Remove old generated themes
cleanupOldThemes(themesDir, themeFile);
// Set by name so pi loads from the file we just wrote
const result = ctx.ui.setTheme(themeName);
if (!result.success) {
ctx.ui.notify(`WezTerm theme sync failed: ${result.error}`, "error");
}
});
}

View File

@@ -13,7 +13,7 @@
"mcp"
],
"directTools": true,
"lifecycle": "eager"
"lifecycle": "lazy"
}
}
}

View File

@@ -0,0 +1,81 @@
{
"$schema": "https://raw.githubusercontent.com/badlogic/pi-mono/main/packages/coding-agent/src/modes/interactive/theme/theme-schema.json",
"name": "wezterm-sync-ba8a76f5",
"vars": {
"bg": "#1c2433",
"fg": "#afbbd2",
"accent": "#b78aff",
"accentAlt": "#ff955c",
"link": "#69c3ff",
"error": "#ff738a",
"success": "#3cec85",
"warning": "#eacd61",
"muted": "#7c869a",
"dim": "#5e687b",
"borderMuted": "#414a5b",
"selectedBg": "#28303f",
"userMsgBg": "#242c3b",
"toolPendingBg": "#212938",
"toolSuccessBg": "#203c3d",
"toolErrorBg": "#372d3d",
"customMsgBg": "#282c43"
},
"colors": {
"accent": "accent",
"border": "link",
"borderAccent": "accent",
"borderMuted": "borderMuted",
"success": "success",
"error": "error",
"warning": "warning",
"muted": "muted",
"dim": "dim",
"text": "",
"thinkingText": "muted",
"selectedBg": "selectedBg",
"userMessageBg": "userMsgBg",
"userMessageText": "",
"customMessageBg": "customMsgBg",
"customMessageText": "",
"customMessageLabel": "accent",
"toolPendingBg": "toolPendingBg",
"toolSuccessBg": "toolSuccessBg",
"toolErrorBg": "toolErrorBg",
"toolTitle": "",
"toolOutput": "muted",
"mdHeading": "warning",
"mdLink": "link",
"mdLinkUrl": "dim",
"mdCode": "accent",
"mdCodeBlock": "success",
"mdCodeBlockBorder": "muted",
"mdQuote": "muted",
"mdQuoteBorder": "muted",
"mdHr": "muted",
"mdListBullet": "accent",
"toolDiffAdded": "success",
"toolDiffRemoved": "error",
"toolDiffContext": "muted",
"syntaxComment": "muted",
"syntaxKeyword": "accent",
"syntaxFunction": "link",
"syntaxVariable": "accentAlt",
"syntaxString": "success",
"syntaxNumber": "accent",
"syntaxType": "accentAlt",
"syntaxOperator": "fg",
"syntaxPunctuation": "muted",
"thinkingOff": "borderMuted",
"thinkingMinimal": "muted",
"thinkingLow": "link",
"thinkingMedium": "accentAlt",
"thinkingHigh": "accent",
"thinkingXhigh": "accent",
"bashMode": "success"
},
"export": {
"pageBg": "#141c2b",
"cardBg": "#1c2433",
"infoBg": "#353839"
}
}

3
pi/.pi/settings.json Normal file
View File

@@ -0,0 +1,3 @@
{
"hide_thinking_block": true
}