Implementer retry med backoff og token-budsjett i synops-agent

- Retry med exponential backoff for retryable API-feil (429, 500, 502, 503)
  med konfigurerbar --max-retries (default: 3) og Retry-After-støtte
- --max-cost flagg for token-budsjett (USD), stopper og rapporterer
  gjenstående arbeid ved budsjettgrense (exit code 2)
- Konfigurerbar --max-tokens per provider (erstatter hardkodet 4096/8192)
- Sanntids kostnadsregnskap per modell med cost_per_million_tokens-tabell
- Detaljert token/kostnad-rapport ved avslutning

Ref: docs/proposals/agent_harness.md §3 (selvovervåking)
This commit is contained in:
vegard 2026-03-19 18:12:27 +00:00
parent 1a6887f334
commit 0bfad1eb8a
2 changed files with 249 additions and 17 deletions

View file

@ -14,7 +14,10 @@ mod tools;
use clap::Parser; use clap::Parser;
use context::{CompactionConfig, CompactionLevel, check_compaction_level, compact_messages}; use context::{CompactionConfig, CompactionLevel, check_compaction_level, compact_messages};
use provider::{ApiKeys, CompletionResponse, Message, TokenUsage, create_provider}; use provider::{
ApiKeys, Message, RetryConfig, TokenUsage,
calculate_cost, complete_with_retry, create_provider,
};
use std::collections::HashMap; use std::collections::HashMap;
use std::path::PathBuf; use std::path::PathBuf;
@ -48,6 +51,18 @@ struct Cli {
/// Spawn Claude Code i stedet (bruker abonnement) /// Spawn Claude Code i stedet (bruker abonnement)
#[arg(long)] #[arg(long)]
claude: bool, claude: bool,
/// Maks kostnad i USD (f.eks. 0.50). Stopper og rapporterer ved grense.
#[arg(long)]
max_cost: Option<f64>,
/// Maks output-tokens per LLM-kall (overstyrer provider-default)
#[arg(long)]
max_tokens: Option<u32>,
/// Maks antall retries ved API-feil (default: 3)
#[arg(long, default_value = "3")]
max_retries: u32,
} }
#[tokio::main] #[tokio::main]
@ -67,7 +82,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
let api_keys = ApiKeys::from_env(); let api_keys = ApiKeys::from_env();
let provider = create_provider(&cli.model, &api_keys)?; let provider = create_provider(&cli.model, &api_keys, cli.max_tokens)?;
tracing::info!( tracing::info!(
model = provider.model_id(), model = provider.model_id(),
@ -98,7 +113,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Token accounting // Token accounting
let mut total_usage: HashMap<String, TokenUsage> = HashMap::new(); let mut total_usage: HashMap<String, TokenUsage> = HashMap::new();
let mut total_cost: f64 = 0.0;
let mut iteration = 0; let mut iteration = 0;
let mut budget_exhausted = false;
// Retry config
let retry_config = RetryConfig {
max_retries: cli.max_retries,
..Default::default()
};
// Context compaction config // Context compaction config
let compaction_config = CompactionConfig { let compaction_config = CompactionConfig {
@ -107,7 +130,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}; };
tracing::info!( tracing::info!(
context_window = compaction_config.context_window, context_window = compaction_config.context_window,
"ACC konfigurert" max_cost = cli.max_cost.map(|c| format!("${:.2}", c)).as_deref().unwrap_or("unlimited"),
"Agent konfigurert"
); );
// === Agent loop === // === Agent loop ===
@ -118,27 +142,49 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
break; break;
} }
// Call LLM // Budget check before calling LLM
let response: CompletionResponse = provider.complete(&messages, &tool_defs).await?; if let Some(max_cost) = cli.max_cost {
if total_cost >= max_cost {
budget_exhausted = true;
tracing::warn!(
cost = format!("${:.4}", total_cost),
budget = format!("${:.2}", max_cost),
"Budsjettgrense nådd — stopper"
);
break;
}
}
// Accumulate token usage // Call LLM with retry
let response = complete_with_retry(
provider.as_ref(),
&messages,
&tool_defs,
&retry_config,
).await?;
// Accumulate token usage and cost
let entry = total_usage let entry = total_usage
.entry(response.model.clone()) .entry(response.model.clone())
.or_insert_with(TokenUsage::default); .or_insert_with(TokenUsage::default);
entry.input_tokens += response.usage.input_tokens; entry.input_tokens += response.usage.input_tokens;
entry.output_tokens += response.usage.output_tokens; entry.output_tokens += response.usage.output_tokens;
let call_cost = calculate_cost(&response.model, &response.usage);
total_cost += call_cost;
if cli.verbose { if cli.verbose {
tracing::info!( tracing::info!(
iteration, iteration,
input = response.usage.input_tokens, input = response.usage.input_tokens,
output = response.usage.output_tokens, output = response.usage.output_tokens,
call_cost = format!("${:.4}", call_cost),
total_cost = format!("${:.4}", total_cost),
"LLM-kall" "LLM-kall"
); );
} }
// === Adaptive Context Compaction === // === Adaptive Context Compaction ===
// Use prompt_tokens from the API response as calibration anchor
let level = check_compaction_level(response.usage.input_tokens, &compaction_config); let level = check_compaction_level(response.usage.input_tokens, &compaction_config);
if level != CompactionLevel::None { if level != CompactionLevel::None {
let ratio = response.usage.input_tokens as f64 / compaction_config.context_window as f64; let ratio = response.usage.input_tokens as f64 / compaction_config.context_window as f64;
@ -212,16 +258,28 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut total_in = 0u64; let mut total_in = 0u64;
let mut total_out = 0u64; let mut total_out = 0u64;
for (model, usage) in &total_usage { for (model, usage) in &total_usage {
let model_cost = calculate_cost(model, usage);
eprintln!( eprintln!(
" {}: {} inn / {} ut", " {}: {} inn / {} ut (${:.4})",
model, usage.input_tokens, usage.output_tokens model, usage.input_tokens, usage.output_tokens, model_cost
); );
total_in += usage.input_tokens; total_in += usage.input_tokens;
total_out += usage.output_tokens; total_out += usage.output_tokens;
} }
eprintln!(" Totalt: {} inn / {} ut", total_in, total_out); eprintln!(" Totalt: {} inn / {} ut", total_in, total_out);
eprintln!(" Kostnad: ${:.4}", total_cost);
if let Some(max_cost) = cli.max_cost {
eprintln!(" Budsjett: ${:.2} ({:.0}% brukt)", max_cost, (total_cost / max_cost) * 100.0);
}
eprintln!(" Iterasjoner: {}", iteration); eprintln!(" Iterasjoner: {}", iteration);
if budget_exhausted {
eprintln!("\n⚠ Budsjettgrense nådd. Oppgaven er ikke fullført.");
eprintln!(" Gjenstående arbeid bør fortsettes med høyere --max-cost");
eprintln!(" eller manuelt. Kontekst kan gjenopprettes fra meldingsloggen.");
std::process::exit(2);
}
Ok(()) Ok(())
} }

View file

@ -96,6 +96,38 @@ pub enum ProviderError {
Parse(String), Parse(String),
#[error("No API key configured for {provider}")] #[error("No API key configured for {provider}")]
NoApiKey { provider: String }, NoApiKey { provider: String },
#[allow(dead_code)]
#[error("Budget exhausted: {0}")]
BudgetExhausted(String),
}
impl ProviderError {
    /// Whether this error is transient and worth retrying.
    ///
    /// Retryable cases:
    /// - HTTP 429 (rate limit), 500/502/503 (server errors), and 529,
    ///   which Anthropic documents as `overloaded_error` and which should
    ///   be treated like 503.
    /// - reqwest connection or timeout failures.
    ///
    /// Everything else (parse errors, missing API keys, budget
    /// exhaustion) is permanent: retrying cannot help.
    pub fn is_retryable(&self) -> bool {
        match self {
            ProviderError::Api { status, .. } => {
                matches!(status, 429 | 500 | 502 | 503 | 529)
            }
            ProviderError::Http(e) => {
                // Connection resets and timeouts are transient by nature.
                e.is_timeout() || e.is_connect()
            }
            _ => false,
        }
    }

    /// Extract a Retry-After hint (whole seconds) from a 429 error body.
    ///
    /// The HTTP `Retry-After` header itself is not available here (only
    /// the response body is captured), but several APIs echo the value as
    /// JSON, either as `{"error": {"retry_after": N}}` or top-level
    /// `{"retry_after": N}`. Numbers are read as floats so fractional
    /// seconds are honored, rounded up so we never retry too early.
    pub fn retry_after_hint(&self) -> Option<u64> {
        if let ProviderError::Api { status: 429, body } = self {
            if let Ok(v) = serde_json::from_str::<serde_json::Value>(body) {
                let secs = v["error"]["retry_after"]
                    .as_f64()
                    .or_else(|| v["retry_after"].as_f64());
                if let Some(secs) = secs {
                    // Guard against NaN / negative values from a
                    // misbehaving server; `as u64` saturates on overflow.
                    if secs.is_finite() && secs >= 0.0 {
                        return Some(secs.ceil() as u64);
                    }
                }
            }
        }
        None
    }
} }
// ============================================================================ // ============================================================================
@ -108,6 +140,7 @@ pub struct OpenAiCompatible {
api_key: String, api_key: String,
model: String, model: String,
provider_name: String, provider_name: String,
max_tokens: u32,
} }
impl OpenAiCompatible { impl OpenAiCompatible {
@ -123,9 +156,15 @@ impl OpenAiCompatible {
api_key: api_key.into(), api_key: api_key.into(),
model: model.into(), model: model.into(),
provider_name: provider_name.into(), provider_name: provider_name.into(),
max_tokens: 4096,
} }
} }
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
/// OpenRouter /// OpenRouter
pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self { pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self::new( Self::new(
@ -177,7 +216,7 @@ impl LlmProvider for OpenAiCompatible {
let mut body = serde_json::json!({ let mut body = serde_json::json!({
"model": self.model, "model": self.model,
"messages": messages, "messages": messages,
"max_tokens": 4096, "max_tokens": self.max_tokens,
}); });
if !tools.is_empty() { if !tools.is_empty() {
@ -251,6 +290,7 @@ pub struct Anthropic {
client: reqwest::Client, client: reqwest::Client,
api_key: String, api_key: String,
model: String, model: String,
max_tokens: u32,
} }
impl Anthropic { impl Anthropic {
@ -259,8 +299,14 @@ impl Anthropic {
client: reqwest::Client::new(), client: reqwest::Client::new(),
api_key: api_key.into(), api_key: api_key.into(),
model: model.into(), model: model.into(),
max_tokens: 8192,
} }
} }
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
} }
#[async_trait::async_trait] #[async_trait::async_trait]
@ -327,7 +373,7 @@ impl LlmProvider for Anthropic {
let mut body = serde_json::json!({ let mut body = serde_json::json!({
"model": self.model, "model": self.model,
"max_tokens": 8192, "max_tokens": self.max_tokens,
"messages": anthropic_messages, "messages": anthropic_messages,
}); });
@ -429,6 +475,7 @@ pub struct Gemini {
client: reqwest::Client, client: reqwest::Client,
api_key: String, api_key: String,
model: String, model: String,
max_tokens: u32,
} }
impl Gemini { impl Gemini {
@ -437,8 +484,14 @@ impl Gemini {
client: reqwest::Client::new(), client: reqwest::Client::new(),
api_key: api_key.into(), api_key: api_key.into(),
model: model.into(), model: model.into(),
max_tokens: 8192,
} }
} }
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
} }
#[async_trait::async_trait] #[async_trait::async_trait]
@ -517,6 +570,9 @@ impl LlmProvider for Gemini {
let mut body = serde_json::json!({ let mut body = serde_json::json!({
"contents": contents, "contents": contents,
"generationConfig": {
"maxOutputTokens": self.max_tokens,
},
}); });
if !system.is_empty() { if !system.is_empty() {
@ -633,6 +689,7 @@ impl LlmProvider for Gemini {
pub fn create_provider( pub fn create_provider(
model_spec: &str, model_spec: &str,
api_keys: &ApiKeys, api_keys: &ApiKeys,
max_tokens: Option<u32>,
) -> Result<Box<dyn LlmProvider>, ProviderError> { ) -> Result<Box<dyn LlmProvider>, ProviderError> {
let (provider, model) = if let Some(idx) = model_spec.find('/') { let (provider, model) = if let Some(idx) = model_spec.find('/') {
(&model_spec[..idx], &model_spec[idx + 1..]) (&model_spec[..idx], &model_spec[idx + 1..])
@ -644,35 +701,152 @@ pub fn create_provider(
"openrouter" => { "openrouter" => {
let key = api_keys.openrouter.as_deref() let key = api_keys.openrouter.as_deref()
.ok_or_else(|| ProviderError::NoApiKey { provider: "openrouter".into() })?; .ok_or_else(|| ProviderError::NoApiKey { provider: "openrouter".into() })?;
Ok(Box::new(OpenAiCompatible::openrouter(key, model))) let mut p = OpenAiCompatible::openrouter(key, model);
if let Some(mt) = max_tokens { p = p.with_max_tokens(mt); }
Ok(Box::new(p))
} }
"anthropic" => { "anthropic" => {
let key = api_keys.anthropic.as_deref() let key = api_keys.anthropic.as_deref()
.ok_or_else(|| ProviderError::NoApiKey { provider: "anthropic".into() })?; .ok_or_else(|| ProviderError::NoApiKey { provider: "anthropic".into() })?;
Ok(Box::new(Anthropic::new(key, model))) let mut p = Anthropic::new(key, model);
if let Some(mt) = max_tokens { p = p.with_max_tokens(mt); }
Ok(Box::new(p))
} }
"gemini" | "google" => { "gemini" | "google" => {
let key = api_keys.gemini.as_deref() let key = api_keys.gemini.as_deref()
.ok_or_else(|| ProviderError::NoApiKey { provider: "gemini".into() })?; .ok_or_else(|| ProviderError::NoApiKey { provider: "gemini".into() })?;
Ok(Box::new(Gemini::new(key, model))) let mut p = Gemini::new(key, model);
if let Some(mt) = max_tokens { p = p.with_max_tokens(mt); }
Ok(Box::new(p))
} }
"xai" | "grok" => { "xai" | "grok" => {
let key = api_keys.xai.as_deref() let key = api_keys.xai.as_deref()
.ok_or_else(|| ProviderError::NoApiKey { provider: "xai".into() })?; .ok_or_else(|| ProviderError::NoApiKey { provider: "xai".into() })?;
Ok(Box::new(OpenAiCompatible::xai(key, model))) let mut p = OpenAiCompatible::xai(key, model);
if let Some(mt) = max_tokens { p = p.with_max_tokens(mt); }
Ok(Box::new(p))
} }
"openai" => { "openai" => {
let key = api_keys.openai.as_deref() let key = api_keys.openai.as_deref()
.ok_or_else(|| ProviderError::NoApiKey { provider: "openai".into() })?; .ok_or_else(|| ProviderError::NoApiKey { provider: "openai".into() })?;
Ok(Box::new(OpenAiCompatible::openai(key, model))) let mut p = OpenAiCompatible::openai(key, model);
if let Some(mt) = max_tokens { p = p.with_max_tokens(mt); }
Ok(Box::new(p))
} }
"ollama" | "local" => { "ollama" | "local" => {
Ok(Box::new(OpenAiCompatible::ollama(model))) let mut p = OpenAiCompatible::ollama(model);
if let Some(mt) = max_tokens { p = p.with_max_tokens(mt); }
Ok(Box::new(p))
} }
_ => Err(ProviderError::Parse(format!("Unknown provider: {}", provider))), _ => Err(ProviderError::Parse(format!("Unknown provider: {}", provider))),
} }
} }
// ============================================================================
// Cost estimation
// ============================================================================
/// Cost per million tokens (input, output) in USD for known models.
///
/// Matching is a case-insensitive substring test, checked in priority
/// order (more specific names first). Local or unknown models are
/// treated as free and return `(0.0, 0.0)`.
pub fn cost_per_million_tokens(model: &str) -> (f64, f64) {
    // (required substrings, (input $/Mtok, output $/Mtok)).
    // First entry whose substrings all match wins, so keep specific
    // names (e.g. "gpt-4o-mini") ahead of their prefixes ("gpt-4o").
    const PRICING: &[(&[&str], (f64, f64))] = &[
        // Anthropic
        (&["opus"], (15.0, 75.0)),
        (&["sonnet"], (3.0, 15.0)),
        (&["haiku"], (0.25, 1.25)),
        // Gemini
        (&["gemini", "flash"], (0.075, 0.30)),
        (&["gemini", "pro"], (1.25, 5.0)),
        // Grok
        (&["grok-3", "mini"], (0.30, 0.50)),
        (&["grok-3"], (3.0, 15.0)),
        (&["grok"], (2.0, 10.0)),
        // OpenAI
        (&["gpt-4o-mini"], (0.15, 0.60)),
        (&["gpt-4o"], (2.50, 10.0)),
    ];

    let name = model.to_lowercase();
    for (needles, rates) in PRICING {
        if needles.iter().all(|n| name.contains(n)) {
            return *rates;
        }
    }
    // OpenAI reasoning models: either substring marks the family.
    if name.contains("o1") || name.contains("o3") {
        return (10.0, 40.0);
    }
    // Local / unknown — free.
    (0.0, 0.0)
}
/// Estimate the USD cost of a call from its token usage and model name.
///
/// Unknown models price at zero (see [`cost_per_million_tokens`]).
pub fn calculate_cost(model: &str, usage: &TokenUsage) -> f64 {
    let (in_rate, out_rate) = cost_per_million_tokens(model);
    let input_cost = usage.input_tokens as f64 * in_rate;
    let output_cost = usage.output_tokens as f64 * out_rate;
    (input_cost + output_cost) / 1_000_000.0
}
// ============================================================================
// Retry with exponential backoff
// ============================================================================
/// Retry configuration for API calls.
///
/// All fields are plain scalars, so the struct is `Copy`. Override a
/// single field with struct-update syntax, e.g.
/// `RetryConfig { max_retries: 5, ..Default::default() }`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryConfig {
    /// Maximum number of retries after the initial attempt (default: 3).
    pub max_retries: u32,
    /// Base backoff delay in milliseconds, doubled per attempt (default: 1000).
    pub base_delay_ms: u64,
    /// Upper bound on any single backoff delay in milliseconds (default: 30000).
    pub max_delay_ms: u64,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
base_delay_ms: 1_000,
max_delay_ms: 30_000,
}
}
}
/// Call `provider.complete()` with retry and exponential backoff on
/// retryable errors (see `ProviderError::is_retryable`).
///
/// Delay schedule: `base_delay_ms * 2^attempt`, saturating (so a large
/// `--max-retries` cannot overflow `u64` and panic), capped at
/// `max_delay_ms`. A `retry_after` hint from a 429 body takes precedence
/// over the computed backoff, but is also capped at `max_delay_ms` so a
/// misbehaving server cannot stall the agent indefinitely.
///
/// # Errors
/// Returns immediately for non-retryable errors, or the final error once
/// `max_retries` additional attempts have been exhausted.
pub async fn complete_with_retry(
    provider: &dyn LlmProvider,
    messages: &[Message],
    tools: &[ToolDef],
    config: &RetryConfig,
) -> Result<CompletionResponse, ProviderError> {
    for attempt in 0..=config.max_retries {
        let err = match provider.complete(messages, tools).await {
            Ok(resp) => return Ok(resp),
            Err(e) => e,
        };
        if !err.is_retryable() || attempt == config.max_retries {
            return Err(err);
        }

        // Exponential backoff: base * 2^attempt, saturating, then capped.
        // Shift amount is clamped so `1u64 << n` stays defined.
        let backoff = config.base_delay_ms.saturating_mul(1u64 << attempt.min(62));
        let delay = err
            .retry_after_hint()
            .map(|secs| secs.saturating_mul(1_000))
            .unwrap_or(backoff)
            .min(config.max_delay_ms);

        tracing::warn!(
            attempt = attempt + 1,
            max_retries = config.max_retries,
            delay_ms = delay,
            error = %err,
            "Retryable API error, backing off"
        );
        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
    }
    // Every loop iteration returns: Ok on success, Err on a non-retryable
    // error or on the final attempt (attempt == max_retries).
    unreachable!("retry loop always exits via return")
}
/// API keys loaded from environment /// API keys loaded from environment
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct ApiKeys { pub struct ApiKeys {