Agent and websocket investigations work fine

This commit is contained in:
Harshavardhan Musanalli
2025-10-27 19:13:39 +01:00
parent 0a8b2dc202
commit 8832450a1f
8 changed files with 1694 additions and 19 deletions

204
agent.go
View File

@@ -102,7 +102,7 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error {
for { for {
// Send request to TensorZero API via OpenAI SDK // Send request to TensorZero API via OpenAI SDK
response, err := a.sendRequest(messages) response, err := a.sendRequestWithEpisode(messages, a.episodeID)
if err != nil { if err != nil {
return fmt.Errorf("failed to send request: %w", err) return fmt.Errorf("failed to send request: %w", err)
} }
@@ -115,34 +115,73 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error {
fmt.Printf("\nAI Response:\n%s\n", content) fmt.Printf("\nAI Response:\n%s\n", content)
// Parse the response to determine next action // Parse the response to determine next action
var diagnosticResp DiagnosticResponse var diagnosticResp EBPFEnhancedDiagnosticResponse
var resolutionResp ResolutionResponse var resolutionResp ResolutionResponse
// Try to parse as diagnostic response first // Try to parse as diagnostic response first (with eBPF support)
if err := json.Unmarshal([]byte(content), &diagnosticResp); err == nil && diagnosticResp.ResponseType == "diagnostic" { if err := json.Unmarshal([]byte(content), &diagnosticResp); err == nil && diagnosticResp.ResponseType == "diagnostic" {
// Handle diagnostic phase // Handle diagnostic phase
fmt.Printf("\nReasoning: %s\n", diagnosticResp.Reasoning) fmt.Printf("\nReasoning: %s\n", diagnosticResp.Reasoning)
if len(diagnosticResp.Commands) == 0 {
fmt.Println("No commands to execute in diagnostic phase")
break
}
// Execute commands and collect results // Execute commands and collect results
commandResults := make([]CommandResult, 0, len(diagnosticResp.Commands)) commandResults := make([]CommandResult, 0, len(diagnosticResp.Commands))
for _, cmd := range diagnosticResp.Commands { if len(diagnosticResp.Commands) > 0 {
fmt.Printf("\nExecuting command '%s': %s\n", cmd.ID, cmd.Command) fmt.Printf("🔧 Executing diagnostic commands...\n")
result := a.executor.Execute(cmd) for _, cmd := range diagnosticResp.Commands {
commandResults = append(commandResults, result) fmt.Printf("⚙️ Executing command '%s': %s\n", cmd.ID, cmd.Command)
result := a.executor.Execute(cmd)
commandResults = append(commandResults, result)
fmt.Printf("Output:\n%s\n", result.Output) if result.ExitCode == 0 {
if result.Error != "" { fmt.Printf("✅ Command '%s' completed successfully\n", cmd.ID)
fmt.Printf("Error: %s\n", result.Error) } else {
fmt.Printf("❌ Command '%s' failed with exit code %d\n", cmd.ID, result.ExitCode)
}
} }
} }
// Prepare command results as user message // Execute eBPF programs if present
resultsJSON, err := json.MarshalIndent(commandResults, "", " ") var ebpfResults []map[string]interface{}
if len(diagnosticResp.EBPFPrograms) > 0 {
fmt.Printf("🔬 Executing %d eBPF programs...\n", len(diagnosticResp.EBPFPrograms))
ebpfResults = a.executeEBPFPrograms(diagnosticResp.EBPFPrograms)
}
// Prepare combined results as user message
allResults := map[string]interface{}{
"command_results": commandResults,
"executed_commands": len(commandResults),
}
// Include eBPF results if any were executed
if len(ebpfResults) > 0 {
allResults["ebpf_results"] = ebpfResults
allResults["executed_ebpf_programs"] = len(ebpfResults)
// Extract evidence summary for TensorZero
evidenceSummary := make([]string, 0)
for _, result := range ebpfResults {
name := result["name"]
eventCount := result["data_points"]
description := result["description"]
status := result["status"]
summaryStr := fmt.Sprintf("%s: %v events (%s) - %s", name, eventCount, status, description)
evidenceSummary = append(evidenceSummary, summaryStr)
}
allResults["ebpf_evidence_summary"] = evidenceSummary
fmt.Printf("<22> Sending eBPF monitoring data to TensorZero:\n")
for _, summary := range evidenceSummary {
fmt.Printf(" - %s\n", summary)
}
fmt.Printf("✅ Executed %d commands, %d eBPF programs\n", len(commandResults), len(ebpfResults))
} else {
fmt.Printf("✅ Executed %d commands\n", len(commandResults))
}
resultsJSON, err := json.MarshalIndent(allResults, "", " ")
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal command results: %w", err) return fmt.Errorf("failed to marshal command results: %w", err)
} }
@@ -178,6 +217,127 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error {
return nil return nil
} }
// executeEBPFPrograms executes REAL eBPF monitoring programs using the actual eBPF manager
func (a *LinuxDiagnosticAgent) executeEBPFPrograms(ebpfPrograms []EBPFRequest) []map[string]interface{} {
var results []map[string]interface{}
if a.ebpfManager == nil {
fmt.Printf("❌ eBPF manager not initialized\n")
return results
}
for _, prog := range ebpfPrograms {
fmt.Printf("🔬 Starting eBPF program [%s]: %s -> %s (%ds)\n", prog.Name, prog.Type, prog.Target, int(prog.Duration))
// Actually start the eBPF program using the real manager
programID, err := a.ebpfManager.StartEBPFProgram(prog)
if err != nil {
fmt.Printf("❌ Failed to start eBPF program [%s]: %v\n", prog.Name, err)
result := map[string]interface{}{
"name": prog.Name,
"type": prog.Type,
"target": prog.Target,
"duration": int(prog.Duration),
"description": prog.Description,
"status": "failed",
"error": err.Error(),
"success": false,
}
results = append(results, result)
continue
}
// Let the eBPF program run for the specified duration
fmt.Printf("⏰ Waiting %d seconds for eBPF program to collect data...\n", int(prog.Duration))
time.Sleep(time.Duration(prog.Duration) * time.Second)
// Give the collectEvents goroutine a moment to finish and store results
fmt.Printf("⏳ Allowing program to complete data collection...\n")
time.Sleep(500 * time.Millisecond)
// Get the results (should be in completedResults now)
fmt.Printf("📊 Getting results for eBPF program [%s]...\n", prog.Name)
// Use a channel to implement timeout for GetProgramResults
type resultPair struct {
trace *EBPFTrace
err error
}
resultChan := make(chan resultPair, 1)
go func() {
trace, err := a.ebpfManager.GetProgramResults(programID)
resultChan <- resultPair{trace, err}
}()
var trace *EBPFTrace
var resultErr error
select {
case result := <-resultChan:
trace = result.trace
resultErr = result.err
case <-time.After(3 * time.Second):
resultErr = fmt.Errorf("timeout getting results after 3 seconds")
}
// Try to stop the program (may already be stopped by collectEvents)
fmt.Printf("🛑 Stopping eBPF program [%s]...\n", prog.Name)
stopErr := a.ebpfManager.StopProgram(programID)
if stopErr != nil {
fmt.Printf("⚠️ eBPF program [%s] cleanup: %v (may have already completed)\n", prog.Name, stopErr)
// Don't return here, we still want to process results if we got them
}
if resultErr != nil {
fmt.Printf("❌ Failed to get results for eBPF program [%s]: %v\n", prog.Name, resultErr)
result := map[string]interface{}{
"name": prog.Name,
"type": prog.Type,
"target": prog.Target,
"duration": int(prog.Duration),
"description": prog.Description,
"status": "collection_failed",
"error": resultErr.Error(),
"success": false,
}
results = append(results, result)
continue
} // Process the real eBPF trace data
result := map[string]interface{}{
"name": prog.Name,
"type": prog.Type,
"target": prog.Target,
"duration": int(prog.Duration),
"description": prog.Description,
"status": "completed",
"success": true,
}
// Extract real data from the trace
if trace != nil {
result["trace_id"] = trace.TraceID
result["data_points"] = trace.EventCount
result["events"] = trace.Events
result["summary"] = trace.Summary
result["process_list"] = trace.ProcessList
result["start_time"] = trace.StartTime.Format(time.RFC3339)
result["end_time"] = trace.EndTime.Format(time.RFC3339)
result["actual_duration"] = trace.EndTime.Sub(trace.StartTime).Seconds()
fmt.Printf("✅ eBPF program [%s] completed - collected %d real events\n", prog.Name, trace.EventCount)
} else {
result["data_points"] = 0
result["error"] = "No trace data returned"
fmt.Printf("⚠️ eBPF program [%s] completed but returned no trace data\n", prog.Name)
}
results = append(results, result)
}
return results
}
// TensorZeroRequest represents a request structure compatible with TensorZero's episode_id // TensorZeroRequest represents a request structure compatible with TensorZero's episode_id
type TensorZeroRequest struct { type TensorZeroRequest struct {
Model string `json:"model"` Model string `json:"model"`
@@ -193,6 +353,11 @@ type TensorZeroResponse struct {
// sendRequest sends a request to the TensorZero API via Supabase proxy with JWT authentication // sendRequest sends a request to the TensorZero API via Supabase proxy with JWT authentication
func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessage) (*openai.ChatCompletionResponse, error) { func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessage) (*openai.ChatCompletionResponse, error) {
return a.sendRequestWithEpisode(messages, "")
}
// sendRequestWithEpisode sends a request with a specific episode ID
func (a *LinuxDiagnosticAgent) sendRequestWithEpisode(messages []openai.ChatCompletionMessage, episodeID string) (*openai.ChatCompletionResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
@@ -202,9 +367,12 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa
Messages: messages, Messages: messages,
} }
// Include tensorzero::episode_id for conversation continuity (if we have one) // Include tensorzero::episode_id for conversation continuity
// Use agent's existing episode ID if available, otherwise use provided one
if a.episodeID != "" { if a.episodeID != "" {
tzRequest.EpisodeID = a.episodeID tzRequest.EpisodeID = a.episodeID
} else if episodeID != "" {
tzRequest.EpisodeID = episodeID
} }
fmt.Printf("Debug: Sending request to model: %s", a.model) fmt.Printf("Debug: Sending request to model: %s", a.model)

1
go.mod
View File

@@ -13,6 +13,7 @@ require (
require ( require (
github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-ole/go-ole v1.2.6 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect

2
go.sum
View File

@@ -9,6 +9,8 @@ github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lG
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=

View File

@@ -2,6 +2,7 @@ package auth
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -395,6 +396,104 @@ func (am *AuthManager) loadRefreshTokenFromBackup() (string, error) {
return refreshToken, nil return refreshToken, nil
} }
// GetCurrentAgentID retrieves the agent ID from cache or JWT token
func (am *AuthManager) GetCurrentAgentID() (string, error) {
// First try to read from local cache
agentID, err := am.loadCachedAgentID()
if err == nil && agentID != "" {
return agentID, nil
}
// Cache miss - extract from JWT token and cache it
token, err := am.LoadToken()
if err != nil {
return "", fmt.Errorf("failed to load token: %w", err)
}
// Extract agent ID from JWT 'sub' field
agentID, err = am.extractAgentIDFromJWT(token.AccessToken)
if err != nil {
return "", fmt.Errorf("failed to extract agent ID from JWT: %w", err)
}
// Cache the agent ID for future use
if err := am.cacheAgentID(agentID); err != nil {
// Log warning but don't fail - we still have the agent ID
fmt.Printf("Warning: Failed to cache agent ID: %v\n", err)
}
return agentID, nil
}
// extractAgentIDFromJWT decodes the JWT token and extracts the agent ID from 'sub' field
func (am *AuthManager) extractAgentIDFromJWT(tokenString string) (string, error) {
// Basic JWT decoding without verification (since we trust Supabase)
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return "", fmt.Errorf("invalid JWT token format")
}
// Decode the payload (second part)
payload := parts[1]
// Add padding if needed for base64 decoding
for len(payload)%4 != 0 {
payload += "="
}
decoded, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
return "", fmt.Errorf("failed to decode JWT payload: %w", err)
}
// Parse JSON payload
var claims map[string]interface{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return "", fmt.Errorf("failed to parse JWT claims: %w", err)
}
// The agent ID is in the 'sub' field (subject)
if agentID, ok := claims["sub"].(string); ok && agentID != "" {
return agentID, nil
}
return "", fmt.Errorf("agent ID (sub) not found in JWT claims")
}
// loadCachedAgentID reads the cached agent ID from local storage
func (am *AuthManager) loadCachedAgentID() (string, error) {
agentIDPath := filepath.Join(TokenStorageDir, "agent_id")
data, err := os.ReadFile(agentIDPath)
if err != nil {
return "", fmt.Errorf("failed to read cached agent ID: %w", err)
}
agentID := strings.TrimSpace(string(data))
if agentID == "" {
return "", fmt.Errorf("cached agent ID is empty")
}
return agentID, nil
}
// cacheAgentID stores the agent ID in local cache
func (am *AuthManager) cacheAgentID(agentID string) error {
// Ensure the directory exists
if err := am.EnsureTokenStorageDir(); err != nil {
return fmt.Errorf("failed to ensure storage directory: %w", err)
}
agentIDPath := filepath.Join(TokenStorageDir, "agent_id")
// Write agent ID to file with secure permissions
if err := os.WriteFile(agentIDPath, []byte(agentID), 0600); err != nil {
return fmt.Errorf("failed to write agent ID cache: %w", err)
}
return nil
}
func (am *AuthManager) getTokenPath() string { func (am *AuthManager) getTokenPath() string {
if am.config.TokenPath != "" { if am.config.TokenPath != "" {
return am.config.TokenPath return am.config.TokenPath

View File

@@ -242,6 +242,9 @@ func (c *Collector) CreateMetricsRequest(agentID string, systemMetrics *types.Sy
"load15": systemMetrics.LoadAvg15, "load15": systemMetrics.LoadAvg15,
}, },
OSInfo: map[string]string{ OSInfo: map[string]string{
"cpu_cores": fmt.Sprintf("%d", systemMetrics.CPUCores),
"memory": fmt.Sprintf("%.1fGi", float64(systemMetrics.MemoryTotal)/(1024*1024*1024)),
"uptime": "unknown", // Will be calculated by the server or client
"platform": systemMetrics.Platform, "platform": systemMetrics.Platform,
"platform_family": systemMetrics.PlatformFamily, "platform_family": systemMetrics.PlatformFamily,
"platform_version": systemMetrics.PlatformVersion, "platform_version": systemMetrics.PlatformVersion,

527
investigation_server.go Normal file
View File

@@ -0,0 +1,527 @@
package main
import (
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
"time"
"nannyagentv2/internal/auth"
"nannyagentv2/internal/metrics"
"github.com/sashabaranov/go-openai"
)
// InvestigationRequest represents a request from Supabase to start an investigation
type InvestigationRequest struct {
InvestigationID string `json:"investigation_id"`
ApplicationGroup string `json:"application_group"`
Issue string `json:"issue"`
Context map[string]string `json:"context"`
Priority string `json:"priority"`
InitiatedBy string `json:"initiated_by"`
}
// InvestigationResponse represents the agent's response to an investigation
type InvestigationResponse struct {
AgentID string `json:"agent_id"`
InvestigationID string `json:"investigation_id"`
Status string `json:"status"`
Commands []CommandResult `json:"commands,omitempty"`
AIResponse string `json:"ai_response,omitempty"`
EpisodeID string `json:"episode_id,omitempty"`
Timestamp time.Time `json:"timestamp"`
Error string `json:"error,omitempty"`
}
// InvestigationServer handles reverse investigation requests from Supabase
type InvestigationServer struct {
agent *LinuxDiagnosticAgent // Original agent for direct user interactions
applicationAgent *LinuxDiagnosticAgent // Separate agent for application-initiated investigations
port string
agentID string
metricsCollector *metrics.Collector
authManager *auth.AuthManager
startTime time.Time
supabaseURL string
}
// NewInvestigationServer creates a new investigation server
func NewInvestigationServer(agent *LinuxDiagnosticAgent, authManager *auth.AuthManager) *InvestigationServer {
port := os.Getenv("AGENT_PORT")
if port == "" {
port = "1234"
}
// Get agent ID from authentication system
var agentID string
if authManager != nil {
if id, err := authManager.GetCurrentAgentID(); err == nil {
agentID = id
fmt.Printf("✅ Retrieved agent ID from auth manager: %s\n", agentID)
} else {
fmt.Printf("❌ Failed to get agent ID from auth manager: %v\n", err)
}
}
// Fallback to environment variable or generate one if auth fails
if agentID == "" {
agentID = os.Getenv("AGENT_ID")
if agentID == "" {
agentID = fmt.Sprintf("agent-%d", time.Now().Unix())
}
}
// Create metrics collector
metricsCollector := metrics.NewCollector("v2.0.0")
// Create a separate agent for application-initiated investigations
applicationAgent := NewLinuxDiagnosticAgent()
// Override the model to use the application-specific function
applicationAgent.model = "tensorzero::function_name::diagnose_and_heal_application"
return &InvestigationServer{
agent: agent,
applicationAgent: applicationAgent,
port: port,
agentID: agentID,
metricsCollector: metricsCollector,
authManager: authManager,
startTime: time.Now(),
supabaseURL: os.Getenv("SUPABASE_PROJECT_URL"),
}
}
// DiagnoseIssueForApplication handles diagnostic requests initiated from application/portal
func (s *InvestigationServer) DiagnoseIssueForApplication(issue, episodeID string) error {
// Set the episode ID on the application agent for continuity
s.applicationAgent.episodeID = episodeID
return s.applicationAgent.DiagnoseIssue(issue)
}
// Start starts the HTTP server and realtime polling for investigation requests
func (s *InvestigationServer) Start() error {
mux := http.NewServeMux()
// Health check endpoint
mux.HandleFunc("/health", s.handleHealth)
// Investigation endpoint
mux.HandleFunc("/investigate", s.handleInvestigation)
// Agent status endpoint
mux.HandleFunc("/status", s.handleStatus)
// Start realtime polling for backend-initiated investigations
if s.supabaseURL != "" && s.authManager != nil {
go s.startRealtimePolling()
fmt.Printf("🔄 Realtime investigation polling enabled\n")
} else {
fmt.Printf("⚠️ Realtime investigation polling disabled (missing Supabase config or auth)\n")
}
server := &http.Server{
Addr: ":" + s.port,
Handler: mux,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
}
fmt.Printf("🔍 Investigation server started on port %s (Agent ID: %s)\n", s.port, s.agentID)
return server.ListenAndServe()
}
// handleHealth responds to health check requests
func (s *InvestigationServer) handleHealth(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
response := map[string]interface{}{
"status": "healthy",
"agent_id": s.agentID,
"timestamp": time.Now(),
"version": "v2.0.0",
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// handleStatus responds with agent status and capabilities
func (s *InvestigationServer) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Collect current system metrics
systemMetrics, err := s.metricsCollector.GatherSystemMetrics()
if err != nil {
http.Error(w, fmt.Sprintf("Failed to collect metrics: %v", err), http.StatusInternalServerError)
return
}
// Convert to metrics request format for consistent data structure
metricsReq := s.metricsCollector.CreateMetricsRequest(s.agentID, systemMetrics)
response := map[string]interface{}{
"agent_id": s.agentID,
"status": "ready",
"capabilities": []string{"system_diagnostics", "ebpf_monitoring", "command_execution", "ai_analysis"},
"system_info": map[string]interface{}{
"os": fmt.Sprintf("%s %s", metricsReq.OSInfo["platform"], metricsReq.OSInfo["platform_version"]),
"kernel": metricsReq.KernelVersion,
"architecture": metricsReq.OSInfo["kernel_arch"],
"cpu_cores": metricsReq.OSInfo["cpu_cores"],
"memory": metricsReq.MemoryUsage,
"private_ips": metricsReq.IPAddress,
"load_average": fmt.Sprintf("%.2f, %.2f, %.2f",
metricsReq.LoadAverages["load1"],
metricsReq.LoadAverages["load5"],
metricsReq.LoadAverages["load15"]),
"disk_usage": fmt.Sprintf("Root: %.0fG/%.0fG (%.0f%% used)",
float64(metricsReq.FilesystemInfo[0].Used)/1024/1024/1024,
float64(metricsReq.FilesystemInfo[0].Total)/1024/1024/1024,
metricsReq.DiskUsage),
},
"uptime": time.Since(s.startTime),
"last_contact": time.Now(),
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// sendCommandResultsToTensorZero sends command results back to TensorZero and continues conversation
func (s *InvestigationServer) sendCommandResultsToTensorZero(diagnosticResp DiagnosticResponse, commandResults []CommandResult) (interface{}, error) {
// Build conversation history like in agent.go
messages := []openai.ChatCompletionMessage{
// Add the original diagnostic response as assistant message
{
Role: openai.ChatMessageRoleAssistant,
Content: fmt.Sprintf(`{"response_type":"diagnostic","reasoning":"%s","commands":%s}`,
diagnosticResp.Reasoning,
mustMarshalJSON(diagnosticResp.Commands)),
},
}
// Add command results as user message (same as agent.go does)
resultsJSON, err := json.MarshalIndent(commandResults, "", " ")
if err != nil {
return nil, fmt.Errorf("failed to marshal command results: %w", err)
}
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: string(resultsJSON),
})
// Send to TensorZero via application agent's sendRequest method
fmt.Printf("🔄 Sending command results to TensorZero for analysis...\n")
response, err := s.applicationAgent.sendRequest(messages)
if err != nil {
return nil, fmt.Errorf("failed to send request to TensorZero: %w", err)
}
if len(response.Choices) == 0 {
return nil, fmt.Errorf("no choices in TensorZero response")
}
content := response.Choices[0].Message.Content
fmt.Printf("🤖 TensorZero continued analysis:\n%s\n", content)
// Try to parse the response to determine if it's diagnostic or resolution
var diagnosticNextResp DiagnosticResponse
var resolutionResp ResolutionResponse
// Check if it's another diagnostic response
if err := json.Unmarshal([]byte(content), &diagnosticNextResp); err == nil && diagnosticNextResp.ResponseType == "diagnostic" {
fmt.Printf("🔄 TensorZero requests %d more commands\n", len(diagnosticNextResp.Commands))
return map[string]interface{}{
"type": "diagnostic",
"response": diagnosticNextResp,
"raw": content,
}, nil
}
// Check if it's a resolution response
if err := json.Unmarshal([]byte(content), &resolutionResp); err == nil && resolutionResp.ResponseType == "resolution" {
fmt.Printf("✅ TensorZero provided final resolution\n")
return map[string]interface{}{
"type": "resolution",
"response": resolutionResp,
"raw": content,
}, nil
}
// Return raw response if we can't parse it
return map[string]interface{}{
"type": "unknown",
"raw": content,
}, nil
}
// Helper function to marshal JSON without errors
func mustMarshalJSON(v interface{}) string {
data, _ := json.Marshal(v)
return string(data)
}
// processInvestigation handles the actual investigation using TensorZero
// This endpoint receives either:
// 1. DiagnosticResponse - Commands and eBPF programs to execute
// 2. ResolutionResponse - Final resolution (no execution needed)
func (s *InvestigationServer) handleInvestigation(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed - only POST accepted", http.StatusMethodNotAllowed)
return
}
// Parse the request body to determine what type of response this is
var requestBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
return
}
// Check the response_type field to determine how to handle this
responseType, ok := requestBody["response_type"].(string)
if !ok {
http.Error(w, "Missing or invalid response_type field", http.StatusBadRequest)
return
}
fmt.Printf("📋 Received investigation payload with response_type: %s\n", responseType)
switch responseType {
case "diagnostic":
// This is a DiagnosticResponse with commands to execute
response := s.handleDiagnosticExecution(requestBody)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
case "resolution":
// This is a ResolutionResponse - final result, just acknowledge
fmt.Printf("📋 Received final resolution from backend\n")
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"success": true,
"message": "Resolution received and acknowledged",
"agent_id": s.agentID,
})
default:
http.Error(w, fmt.Sprintf("Unknown response_type: %s", responseType), http.StatusBadRequest)
return
}
}
// handleDiagnosticExecution executes commands from a DiagnosticResponse
func (s *InvestigationServer) handleDiagnosticExecution(requestBody map[string]interface{}) map[string]interface{} {
// Parse as DiagnosticResponse
var diagnosticResp DiagnosticResponse
// Convert the map back to JSON and then parse it properly
jsonData, err := json.Marshal(requestBody)
if err != nil {
return map[string]interface{}{
"success": false,
"error": fmt.Sprintf("Failed to re-marshal request: %v", err),
"agent_id": s.agentID,
}
}
if err := json.Unmarshal(jsonData, &diagnosticResp); err != nil {
return map[string]interface{}{
"success": false,
"error": fmt.Sprintf("Failed to parse DiagnosticResponse: %v", err),
"agent_id": s.agentID,
}
}
fmt.Printf("📋 Executing %d commands from backend\n", len(diagnosticResp.Commands))
// Execute all commands
commandResults := make([]CommandResult, 0, len(diagnosticResp.Commands))
for _, cmd := range diagnosticResp.Commands {
fmt.Printf("⚙️ Executing command '%s': %s\n", cmd.ID, cmd.Command)
// Use the agent's executor to run the command
result := s.agent.executor.Execute(cmd)
commandResults = append(commandResults, result)
fmt.Printf("✅ Command '%s' completed with exit code %d\n", cmd.ID, result.ExitCode)
if result.Error != "" {
fmt.Printf("⚠️ Command '%s' had error: %s\n", cmd.ID, result.Error)
}
}
// Send command results back to TensorZero for continued analysis
fmt.Printf("🔄 Sending %d command results back to TensorZero for continued analysis\n", len(commandResults))
nextResponse, err := s.sendCommandResultsToTensorZero(diagnosticResp, commandResults)
if err != nil {
return map[string]interface{}{
"success": false,
"error": fmt.Sprintf("Failed to continue TensorZero conversation: %v", err),
"agent_id": s.agentID,
"command_results": commandResults, // Still return the results
}
}
// Return both the command results and the next response from TensorZero
return map[string]interface{}{
"success": true,
"agent_id": s.agentID,
"command_results": commandResults,
"commands_executed": len(commandResults),
"next_response": nextResponse,
"timestamp": time.Now().Format(time.RFC3339),
}
}
// PendingInvestigation represents a pending investigation from the database
type PendingInvestigation struct {
ID string `json:"id"`
InvestigationID string `json:"investigation_id"`
AgentID string `json:"agent_id"`
DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
EpisodeID *string `json:"episode_id"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
// startRealtimePolling begins polling for pending investigations
func (s *InvestigationServer) startRealtimePolling() {
fmt.Printf("🔄 Starting realtime investigation polling for agent %s\n", s.agentID)
ticker := time.NewTicker(5 * time.Second) // Poll every 5 seconds
defer ticker.Stop()
for range ticker.C {
s.checkForPendingInvestigations()
}
}
// checkForPendingInvestigations checks for new pending investigations
func (s *InvestigationServer) checkForPendingInvestigations() {
url := fmt.Sprintf("%s/rest/v1/pending_investigations?agent_id=eq.%s&status=eq.pending&order=created_at.desc",
s.supabaseURL, s.agentID)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return // Silent fail for polling
}
// Get token from auth manager
authToken, err := s.authManager.LoadToken()
if err != nil {
return // Silent fail for polling
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken.AccessToken))
req.Header.Set("Accept", "application/json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return // Silent fail for polling
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return // Silent fail for polling
}
var investigations []PendingInvestigation
err = json.NewDecoder(resp.Body).Decode(&investigations)
if err != nil {
return // Silent fail for polling
}
for _, investigation := range investigations {
fmt.Printf("🔍 Found pending investigation: %s\n", investigation.ID)
go s.handlePendingInvestigation(investigation)
}
}
// handlePendingInvestigation processes a single pending investigation
func (s *InvestigationServer) handlePendingInvestigation(investigation PendingInvestigation) {
fmt.Printf("🚀 Processing realtime investigation %s\n", investigation.InvestigationID)
// Mark as executing
err := s.updateInvestigationStatus(investigation.ID, "executing", nil, nil)
if err != nil {
fmt.Printf("❌ Failed to mark investigation as executing: %v\n", err)
return
}
// Execute diagnostic commands using existing handleDiagnosticExecution method
results := s.handleDiagnosticExecution(investigation.DiagnosticPayload)
// Mark as completed with results
err = s.updateInvestigationStatus(investigation.ID, "completed", results, nil)
if err != nil {
fmt.Printf("❌ Failed to mark investigation as completed: %v\n", err)
return
}
fmt.Printf("✅ Realtime investigation %s completed successfully\n", investigation.InvestigationID)
}
// updateInvestigationStatus updates the status of a pending investigation
func (s *InvestigationServer) updateInvestigationStatus(id, status string, results map[string]interface{}, errorMsg *string) error {
updateData := map[string]interface{}{
"status": status,
}
if status == "executing" {
updateData["started_at"] = time.Now().UTC().Format(time.RFC3339)
} else if status == "completed" {
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
if results != nil {
updateData["command_results"] = results
}
} else if status == "failed" && errorMsg != nil {
updateData["error_message"] = *errorMsg
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
}
jsonData, err := json.Marshal(updateData)
if err != nil {
return fmt.Errorf("failed to marshal update data: %v", err)
}
url := fmt.Sprintf("%s/rest/v1/pending_investigations?id=eq.%s", s.supabaseURL, id)
req, err := http.NewRequest("PATCH", url, strings.NewReader(string(jsonData)))
if err != nil {
return fmt.Errorf("failed to create request: %v", err)
}
// Get token from auth manager
authToken, err := s.authManager.LoadToken()
if err != nil {
return fmt.Errorf("failed to load auth token: %v", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken.AccessToken))
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to update investigation: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 && resp.StatusCode != 204 {
return fmt.Errorf("supabase update error: %d", resp.StatusCode)
}
return nil
}

14
main.go
View File

@@ -168,9 +168,21 @@ func main() {
fmt.Println("✅ Authentication successful!") fmt.Println("✅ Authentication successful!")
// Initialize the diagnostic agent // Initialize the diagnostic agent for interactive CLI use
agent := NewLinuxDiagnosticAgent() agent := NewLinuxDiagnosticAgent()
// Initialize a separate agent for WebSocket investigations using the application model
applicationAgent := NewLinuxDiagnosticAgent()
applicationAgent.model = "tensorzero::function_name::diagnose_and_heal_application"
// Start WebSocket client for backend communications and investigations
wsClient := NewWebSocketClient(applicationAgent, authManager)
go func() {
if err := wsClient.Start(); err != nil {
log.Printf("❌ WebSocket client error: %v", err)
}
}()
// Start background metrics collection in a goroutine // Start background metrics collection in a goroutine
go func() { go func() {
fmt.Println("❤️ Starting background metrics collection and heartbeat...") fmt.Println("❤️ Starting background metrics collection and heartbeat...")

863
websocket_client.go Normal file
View File

@@ -0,0 +1,863 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"os"
"os/exec"
"strings"
"time"
"nannyagentv2/internal/auth"
"nannyagentv2/internal/metrics"
"github.com/gorilla/websocket"
"github.com/sashabaranov/go-openai"
)
// Helper function for minimum of two integers
// WebSocketMessage represents a message sent over WebSocket
type WebSocketMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
}
// InvestigationTask represents a task sent to the agent
type InvestigationTask struct {
TaskID string `json:"task_id"`
InvestigationID string `json:"investigation_id"`
AgentID string `json:"agent_id"`
DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
EpisodeID string `json:"episode_id,omitempty"`
}
// TaskResult represents the result of a completed task
type TaskResult struct {
TaskID string `json:"task_id"`
Success bool `json:"success"`
CommandResults map[string]interface{} `json:"command_results,omitempty"`
Error string `json:"error,omitempty"`
}
// HeartbeatData represents heartbeat information
type HeartbeatData struct {
AgentID string `json:"agent_id"`
Timestamp time.Time `json:"timestamp"`
Version string `json:"version"`
}
// WebSocketClient handles WebSocket connection to Supabase backend
type WebSocketClient struct {
agent *LinuxDiagnosticAgent
conn *websocket.Conn
agentID string
authManager *auth.AuthManager
metricsCollector *metrics.Collector
supabaseURL string
token string
ctx context.Context
cancel context.CancelFunc
consecutiveFailures int // Track consecutive connection failures
}
// NewWebSocketClient creates a new WebSocket client
func NewWebSocketClient(agent *LinuxDiagnosticAgent, authManager *auth.AuthManager) *WebSocketClient {
// Get agent ID from authentication system
var agentID string
if authManager != nil {
if id, err := authManager.GetCurrentAgentID(); err == nil {
agentID = id
// Agent ID retrieved successfully
} else {
fmt.Printf("❌ Failed to get agent ID from auth manager: %v\n", err)
}
}
// Fallback to environment variable or generate one if auth fails
if agentID == "" {
agentID = os.Getenv("AGENT_ID")
if agentID == "" {
agentID = fmt.Sprintf("agent-%d", time.Now().Unix())
}
}
supabaseURL := os.Getenv("SUPABASE_PROJECT_URL")
if supabaseURL == "" {
log.Fatal("❌ SUPABASE_PROJECT_URL environment variable is required")
}
// Create metrics collector
metricsCollector := metrics.NewCollector("v2.0.0")
ctx, cancel := context.WithCancel(context.Background())
return &WebSocketClient{
agent: agent,
agentID: agentID,
authManager: authManager,
metricsCollector: metricsCollector,
supabaseURL: supabaseURL,
ctx: ctx,
cancel: cancel,
}
}
// Start starts the WebSocket connection and message handling
func (w *WebSocketClient) Start() error {
// Starting WebSocket client
if err := w.connect(); err != nil {
return fmt.Errorf("failed to establish WebSocket connection: %v", err)
}
// Start message reading loop
go w.handleMessages()
// Start heartbeat
go w.startHeartbeat()
// Start database polling for pending investigations
go w.pollPendingInvestigations()
// WebSocket client started
return nil
}
// Stop closes the WebSocket connection
func (c *WebSocketClient) Stop() {
fmt.Println("🛑 Stopping WebSocket client...")
c.cancel()
if c.conn != nil {
c.conn.Close()
}
}
// getAuthToken retrieves authentication token
func (c *WebSocketClient) getAuthToken() error {
if c.authManager == nil {
return fmt.Errorf("auth manager not available")
}
token, err := c.authManager.EnsureAuthenticated()
if err != nil {
return fmt.Errorf("authentication failed: %v", err)
}
c.token = token.AccessToken
return nil
}
// connect establishes WebSocket connection
func (c *WebSocketClient) connect() error {
// Get fresh auth token
if err := c.getAuthToken(); err != nil {
return fmt.Errorf("failed to get auth token: %v", err)
}
// Convert HTTP URL to WebSocket URL
wsURL := strings.Replace(c.supabaseURL, "https://", "wss://", 1)
wsURL = strings.Replace(wsURL, "http://", "ws://", 1)
wsURL += "/functions/v1/websocket-agent-handler"
// Connecting to WebSocket
// Set up headers
headers := http.Header{}
headers.Set("Authorization", "Bearer "+c.token)
// Connect
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
conn, resp, err := dialer.Dial(wsURL, headers)
if err != nil {
c.consecutiveFailures++
if c.consecutiveFailures >= 5 && resp != nil {
fmt.Printf("❌ WebSocket handshake failed with status: %d (failure #%d)\n", resp.StatusCode, c.consecutiveFailures)
}
return fmt.Errorf("websocket connection failed: %v", err)
}
c.conn = conn
// WebSocket client connected
return nil
}
// handleMessages processes incoming WebSocket messages
func (c *WebSocketClient) handleMessages() {
defer func() {
if c.conn != nil {
// Closing WebSocket connection
c.conn.Close()
}
}()
// Started WebSocket message listener
connectionStart := time.Now()
for {
select {
case <-c.ctx.Done():
// Only log context cancellation if there have been failures
if c.consecutiveFailures >= 5 {
fmt.Printf("📡 Context cancelled after %v, stopping message handler\n", time.Since(connectionStart))
}
return
default:
// Set read deadline to detect connection issues
c.conn.SetReadDeadline(time.Now().Add(90 * time.Second))
var message WebSocketMessage
readStart := time.Now()
err := c.conn.ReadJSON(&message)
readDuration := time.Since(readStart)
if err != nil {
connectionDuration := time.Since(connectionStart)
// Only log specific errors after failure threshold
if c.consecutiveFailures >= 5 {
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
log.Printf("🔒 WebSocket closed normally after %v: %v", connectionDuration, err)
} else if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("💥 ABNORMAL CLOSE after %v (code 1006 = server-side timeout/kill): %v", connectionDuration, err)
log.Printf("🕒 Last read took %v, connection lived %v", readDuration, connectionDuration)
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
log.Printf("⏰ READ TIMEOUT after %v: %v", connectionDuration, err)
} else {
log.Printf("❌ WebSocket error after %v: %v", connectionDuration, err)
}
}
// Track consecutive failures for diagnostic threshold
c.consecutiveFailures++
// Only show diagnostics after multiple failures
if c.consecutiveFailures >= 5 {
log.Printf("🔍 DIAGNOSTIC - Connection failed #%d after %v", c.consecutiveFailures, connectionDuration)
}
// Attempt reconnection instead of returning immediately
go c.attemptReconnection()
return
}
// Received WebSocket message successfully - reset failure counter
c.consecutiveFailures = 0
switch message.Type {
case "connection_ack":
// Connection acknowledged
case "heartbeat_ack":
// Heartbeat acknowledged
case "investigation_task":
// Received investigation task - processing
go c.handleInvestigationTask(message.Data)
case "task_result_ack":
// Task result acknowledged
default:
log.Printf("⚠️ Unknown message type: %s", message.Type)
}
}
}
}
// handleInvestigationTask processes investigation tasks from the backend
func (c *WebSocketClient) handleInvestigationTask(data interface{}) {
// Parse task data
taskBytes, err := json.Marshal(data)
if err != nil {
log.Printf("❌ Error marshaling task data: %v", err)
return
}
var task InvestigationTask
err = json.Unmarshal(taskBytes, &task)
if err != nil {
log.Printf("❌ Error unmarshaling investigation task: %v", err)
return
}
// Processing investigation task
// Execute diagnostic commands
results, err := c.executeDiagnosticCommands(task.DiagnosticPayload)
// Prepare task result
taskResult := TaskResult{
TaskID: task.TaskID,
Success: err == nil,
}
if err != nil {
taskResult.Error = err.Error()
fmt.Printf("❌ Task execution failed: %v\n", err)
} else {
taskResult.CommandResults = results
// Task executed successfully
}
// Send result back
c.sendTaskResult(taskResult)
}
// executeDiagnosticCommands executes the commands from a diagnostic response
func (c *WebSocketClient) executeDiagnosticCommands(diagnosticPayload map[string]interface{}) (map[string]interface{}, error) {
fmt.Println("🔧 Executing diagnostic commands...")
results := map[string]interface{}{
"agent_id": c.agentID,
"execution_time": time.Now().UTC().Format(time.RFC3339),
"command_results": []map[string]interface{}{},
}
// Extract commands from diagnostic payload
commands, ok := diagnosticPayload["commands"].([]interface{})
if !ok {
return nil, fmt.Errorf("no commands found in diagnostic payload")
}
var commandResults []map[string]interface{}
for _, cmd := range commands {
cmdMap, ok := cmd.(map[string]interface{})
if !ok {
continue
}
id, _ := cmdMap["id"].(string)
command, _ := cmdMap["command"].(string)
description, _ := cmdMap["description"].(string)
if command == "" {
continue
}
// Executing command
// Execute the command
output, exitCode, err := c.executeCommand(command)
result := map[string]interface{}{
"id": id,
"command": command,
"description": description,
"output": output,
"exit_code": exitCode,
"success": err == nil && exitCode == 0,
}
if err != nil {
result["error"] = err.Error()
fmt.Printf("❌ Command [%s] failed: %v (exit code: %d)\n", id, err, exitCode)
} else {
// Command completed successfully - output captured
}
commandResults = append(commandResults, result)
}
results["command_results"] = commandResults
results["total_commands"] = len(commandResults)
results["successful_commands"] = c.countSuccessfulCommands(commandResults)
// Execute eBPF programs if present
ebpfPrograms, hasEBPF := diagnosticPayload["ebpf_programs"].([]interface{})
if hasEBPF && len(ebpfPrograms) > 0 {
fmt.Printf("🔬 Executing %d eBPF programs...\n", len(ebpfPrograms))
ebpfResults := c.executeEBPFPrograms(ebpfPrograms)
results["ebpf_results"] = ebpfResults
results["total_ebpf_programs"] = len(ebpfPrograms)
} else {
fmt.Printf(" No eBPF programs in diagnostic payload\n")
}
fmt.Printf("✅ Executed %d commands, %d successful\n",
results["total_commands"], results["successful_commands"])
return results, nil
}
// executeEBPFPrograms executes eBPF monitoring programs using the real eBPF manager
func (c *WebSocketClient) executeEBPFPrograms(ebpfPrograms []interface{}) []map[string]interface{} {
var ebpfRequests []EBPFRequest
// Convert interface{} to EBPFRequest structs
for _, prog := range ebpfPrograms {
progMap, ok := prog.(map[string]interface{})
if !ok {
continue
}
name, _ := progMap["name"].(string)
progType, _ := progMap["type"].(string)
target, _ := progMap["target"].(string)
duration, _ := progMap["duration"].(float64)
description, _ := progMap["description"].(string)
if name == "" || progType == "" || target == "" {
continue
}
ebpfRequests = append(ebpfRequests, EBPFRequest{
Name: name,
Type: progType,
Target: target,
Duration: int(duration),
Description: description,
})
}
// Execute eBPF programs using the agent's eBPF execution logic
return c.agent.executeEBPFPrograms(ebpfRequests)
}
// executeCommandsFromPayload executes commands from a payload and returns results
func (c *WebSocketClient) executeCommandsFromPayload(commands []interface{}) []map[string]interface{} {
var commandResults []map[string]interface{}
for _, cmd := range commands {
cmdMap, ok := cmd.(map[string]interface{})
if !ok {
continue
}
id, _ := cmdMap["id"].(string)
command, _ := cmdMap["command"].(string)
description, _ := cmdMap["description"].(string)
if command == "" {
continue
}
// Execute the command
output, exitCode, err := c.executeCommand(command)
result := map[string]interface{}{
"id": id,
"command": command,
"description": description,
"output": output,
"exit_code": exitCode,
"success": err == nil && exitCode == 0,
}
if err != nil {
result["error"] = err.Error()
fmt.Printf("❌ Command [%s] failed: %v (exit code: %d)\n", id, err, exitCode)
} else {
fmt.Printf("✅ Command [%s] completed successfully\n", id)
}
commandResults = append(commandResults, result)
}
return commandResults
}
// executeCommand executes a shell command and returns output, exit code, and error
func (c *WebSocketClient) executeCommand(command string) (string, int, error) {
// Parse command into parts
parts := strings.Fields(command)
if len(parts) == 0 {
return "", -1, fmt.Errorf("empty command")
}
// Create command with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, parts[0], parts[1:]...)
cmd.Env = os.Environ()
output, err := cmd.CombinedOutput()
exitCode := 0
if err != nil {
if exitError, ok := err.(*exec.ExitError); ok {
exitCode = exitError.ExitCode()
} else {
exitCode = -1
}
}
return string(output), exitCode, err
}
// countSuccessfulCommands counts the number of successful commands
func (c *WebSocketClient) countSuccessfulCommands(results []map[string]interface{}) int {
count := 0
for _, result := range results {
if success, ok := result["success"].(bool); ok && success {
count++
}
}
return count
}
// sendTaskResult sends a task result back to the backend
func (c *WebSocketClient) sendTaskResult(result TaskResult) {
message := WebSocketMessage{
Type: "task_result",
Data: result,
}
err := c.conn.WriteJSON(message)
if err != nil {
log.Printf("❌ Error sending task result: %v", err)
}
}
// startHeartbeat sends periodic heartbeat messages
func (c *WebSocketClient) startHeartbeat() {
ticker := time.NewTicker(30 * time.Second) // Heartbeat every 30 seconds
defer ticker.Stop()
// Starting heartbeat
for {
select {
case <-c.ctx.Done():
fmt.Printf("💓 Heartbeat stopped due to context cancellation\n")
return
case <-ticker.C:
// Sending heartbeat
heartbeat := WebSocketMessage{
Type: "heartbeat",
Data: HeartbeatData{
AgentID: c.agentID,
Timestamp: time.Now(),
Version: "v2.0.0",
},
}
err := c.conn.WriteJSON(heartbeat)
if err != nil {
log.Printf("❌ Error sending heartbeat: %v", err)
fmt.Printf("💓 Heartbeat failed, connection likely dead\n")
return
}
// Heartbeat sent
}
}
}
// pollPendingInvestigations polls the database for pending investigations
func (c *WebSocketClient) pollPendingInvestigations() {
// Starting database polling
ticker := time.NewTicker(5 * time.Second) // Poll every 5 seconds
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
c.checkForPendingInvestigations()
}
}
}
// checkForPendingInvestigations checks the database for new pending investigations via proxy
func (c *WebSocketClient) checkForPendingInvestigations() {
// Use Edge Function proxy instead of direct database access
url := fmt.Sprintf("%s/functions/v1/agent-database-proxy/pending-investigations", c.supabaseURL)
// Poll database for pending investigations
req, err := http.NewRequest("GET", url, nil)
if err != nil {
// Request creation failed
return
}
// Only JWT token needed for proxy - no API keys exposed
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token))
req.Header.Set("Accept", "application/json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
// Database request failed
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return
}
var investigations []PendingInvestigation
err = json.NewDecoder(resp.Body).Decode(&investigations)
if err != nil {
// Response decode failed
return
}
for _, investigation := range investigations {
go c.handlePendingInvestigation(investigation)
}
}
// handlePendingInvestigation processes a pending investigation from database polling
func (c *WebSocketClient) handlePendingInvestigation(investigation PendingInvestigation) {
// Processing pending investigation
// Mark as executing
err := c.updateInvestigationStatus(investigation.ID, "executing", nil, nil)
if err != nil {
return
}
// Execute diagnostic commands
results, err := c.executeDiagnosticCommands(investigation.DiagnosticPayload)
// Prepare the base results map we'll send to DB
resultsForDB := map[string]interface{}{
"agent_id": c.agentID,
"execution_time": time.Now().UTC().Format(time.RFC3339),
"command_results": results,
}
// If command execution failed, mark investigation as failed
if err != nil {
errorMsg := err.Error()
// Include partial results when possible
if results != nil {
resultsForDB["command_results"] = results
}
c.updateInvestigationStatus(investigation.ID, "failed", resultsForDB, &errorMsg)
// Investigation failed
return
}
// Try to continue the TensorZero conversation by sending command results back
// Build messages: assistant = diagnostic payload, user = command results
diagJSON, _ := json.Marshal(investigation.DiagnosticPayload)
commandsJSON, _ := json.MarshalIndent(results, "", " ")
messages := []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleAssistant,
Content: string(diagJSON),
},
{
Role: openai.ChatMessageRoleUser,
Content: string(commandsJSON),
},
}
fmt.Printf("🔄 Sending command results to TensorZero for continued analysis...\n")
fmt.Printf("📤 Command results payload size: %d bytes\n", len(commandsJSON))
// Use the episode ID from the investigation to maintain conversation continuity
episodeID := ""
if investigation.EpisodeID != nil {
episodeID = *investigation.EpisodeID
fmt.Printf("🔗 Using episode ID: %s\n", episodeID)
}
// Continue conversation until resolution (same as agent)
var finalAIContent string
for {
tzResp, tzErr := c.agent.sendRequestWithEpisode(messages, episodeID)
if tzErr != nil {
fmt.Printf("⚠️ TensorZero continuation failed: %v\n", tzErr)
// Fall back to marking completed with command results only
c.updateInvestigationStatus(investigation.ID, "completed", resultsForDB, nil)
return
}
fmt.Printf("✅ TensorZero responded successfully\n")
if len(tzResp.Choices) == 0 {
fmt.Printf("⚠️ No choices in TensorZero response\n")
c.updateInvestigationStatus(investigation.ID, "completed", resultsForDB, nil)
return
}
aiContent := tzResp.Choices[0].Message.Content
if len(aiContent) > 300 {
fmt.Printf("🤖 AI Response preview: %s...\n", aiContent[:300])
} else {
fmt.Printf("🤖 AI Response: %s\n", aiContent)
}
// Check if this is a resolution response (final)
var resolutionResp struct {
ResponseType string `json:"response_type"`
RootCause string `json:"root_cause"`
ResolutionPlan string `json:"resolution_plan"`
Confidence string `json:"confidence"`
}
fmt.Printf("🔍 Analyzing AI response type...\n")
if err := json.Unmarshal([]byte(aiContent), &resolutionResp); err == nil && resolutionResp.ResponseType == "resolution" {
// This is the final resolution - show summary and complete
fmt.Printf("✅ Detected RESOLUTION response - completing investigation\n")
fmt.Printf("\n=== DIAGNOSIS COMPLETE ===\n")
fmt.Printf("Root Cause: %s\n", resolutionResp.RootCause)
fmt.Printf("Resolution Plan: %s\n", resolutionResp.ResolutionPlan)
fmt.Printf("Confidence: %s\n", resolutionResp.Confidence)
finalAIContent = aiContent
break
}
// Check if this is another diagnostic response requiring more commands
var diagnosticResp struct {
ResponseType string `json:"response_type"`
Commands []interface{} `json:"commands"`
EBPFPrograms []interface{} `json:"ebpf_programs"`
}
if err := json.Unmarshal([]byte(aiContent), &diagnosticResp); err == nil && diagnosticResp.ResponseType == "diagnostic" {
fmt.Printf("✅ Detected DIAGNOSTIC response - continuing conversation\n")
fmt.Printf("🔄 AI requested additional diagnostics, executing...\n")
// Execute additional commands if any
additionalResults := map[string]interface{}{
"command_results": []map[string]interface{}{},
}
if len(diagnosticResp.Commands) > 0 {
fmt.Printf("🔧 Executing %d additional diagnostic commands...\n", len(diagnosticResp.Commands))
commandResults := c.executeCommandsFromPayload(diagnosticResp.Commands)
additionalResults["command_results"] = commandResults
}
// Execute additional eBPF programs if any
if len(diagnosticResp.EBPFPrograms) > 0 {
fmt.Printf("🔬 Executing %d additional eBPF programs...\n", len(diagnosticResp.EBPFPrograms))
ebpfResults := c.executeEBPFPrograms(diagnosticResp.EBPFPrograms)
additionalResults["ebpf_results"] = ebpfResults
}
// Add AI response and additional results to conversation
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: aiContent,
})
additionalResultsJSON, _ := json.MarshalIndent(additionalResults, "", " ")
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: string(additionalResultsJSON),
})
continue
}
// If neither resolution nor diagnostic, treat as final response
fmt.Printf("⚠️ Unknown response type - treating as final response\n")
finalAIContent = aiContent
break
}
// Attach final AI response to results for DB and mark as completed_with_analysis
resultsForDB["ai_response"] = finalAIContent
fmt.Printf("💾 Updating database with results and AI analysis...\n")
c.updateInvestigationStatus(investigation.ID, "completed_with_analysis", resultsForDB, nil)
fmt.Printf("✅ Investigation completed with AI analysis\n")
}
// updateInvestigationStatus updates the status of a pending investigation
func (c *WebSocketClient) updateInvestigationStatus(id, status string, results map[string]interface{}, errorMsg *string) error {
updateData := map[string]interface{}{
"status": status,
}
if status == "executing" {
updateData["started_at"] = time.Now().UTC().Format(time.RFC3339)
} else if status == "completed" {
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
if results != nil {
updateData["command_results"] = results
}
} else if status == "failed" && errorMsg != nil {
updateData["error_message"] = *errorMsg
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
}
jsonData, err := json.Marshal(updateData)
if err != nil {
return fmt.Errorf("failed to marshal update data: %v", err)
}
url := fmt.Sprintf("%s/functions/v1/agent-database-proxy/pending-investigations/%s", c.supabaseURL, id)
req, err := http.NewRequest("PATCH", url, strings.NewReader(string(jsonData)))
if err != nil {
return fmt.Errorf("failed to create request: %v", err)
}
// Only JWT token needed for proxy - no API keys exposed
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token))
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to update investigation: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 && resp.StatusCode != 204 {
return fmt.Errorf("supabase update error: %d", resp.StatusCode)
}
return nil
}
// attemptReconnection attempts to reconnect the WebSocket with backoff
func (c *WebSocketClient) attemptReconnection() {
backoffDurations := []time.Duration{
2 * time.Second,
5 * time.Second,
10 * time.Second,
20 * time.Second,
30 * time.Second,
}
for i, backoff := range backoffDurations {
select {
case <-c.ctx.Done():
return
default:
c.consecutiveFailures++
// Only show messages after 5 consecutive failures
if c.consecutiveFailures >= 5 {
log.Printf("🔄 Attempting WebSocket reconnection (attempt %d/%d) - %d consecutive failures", i+1, len(backoffDurations), c.consecutiveFailures)
}
time.Sleep(backoff)
if err := c.connect(); err != nil {
if c.consecutiveFailures >= 5 {
log.Printf("❌ Reconnection attempt %d failed: %v", i+1, err)
}
continue
}
// Successfully reconnected - reset failure counter
if c.consecutiveFailures >= 5 {
log.Printf("✅ WebSocket reconnected successfully after %d failures", c.consecutiveFailures)
}
c.consecutiveFailures = 0
go c.handleMessages() // Restart message handling
return
}
}
log.Printf("❌ Failed to reconnect after %d attempts, giving up", len(backoffDurations))
}