diff --git a/agent.go b/agent.go index 1c9a64d..67ecc9d 100644 --- a/agent.go +++ b/agent.go @@ -102,7 +102,7 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error { for { // Send request to TensorZero API via OpenAI SDK - response, err := a.sendRequest(messages) + response, err := a.sendRequestWithEpisode(messages, a.episodeID) if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -115,34 +115,73 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error { fmt.Printf("\nAI Response:\n%s\n", content) // Parse the response to determine next action - var diagnosticResp DiagnosticResponse + var diagnosticResp EBPFEnhancedDiagnosticResponse var resolutionResp ResolutionResponse - // Try to parse as diagnostic response first + // Try to parse as diagnostic response first (with eBPF support) if err := json.Unmarshal([]byte(content), &diagnosticResp); err == nil && diagnosticResp.ResponseType == "diagnostic" { // Handle diagnostic phase fmt.Printf("\nReasoning: %s\n", diagnosticResp.Reasoning) - if len(diagnosticResp.Commands) == 0 { - fmt.Println("No commands to execute in diagnostic phase") - break - } - // Execute commands and collect results commandResults := make([]CommandResult, 0, len(diagnosticResp.Commands)) - for _, cmd := range diagnosticResp.Commands { - fmt.Printf("\nExecuting command '%s': %s\n", cmd.ID, cmd.Command) - result := a.executor.Execute(cmd) - commandResults = append(commandResults, result) + if len(diagnosticResp.Commands) > 0 { + fmt.Printf("🔧 Executing diagnostic commands...\n") + for _, cmd := range diagnosticResp.Commands { + fmt.Printf("âš™ī¸ Executing command '%s': %s\n", cmd.ID, cmd.Command) + result := a.executor.Execute(cmd) + commandResults = append(commandResults, result) - fmt.Printf("Output:\n%s\n", result.Output) - if result.Error != "" { - fmt.Printf("Error: %s\n", result.Error) + if result.ExitCode == 0 { + fmt.Printf("✅ Command '%s' completed successfully\n", cmd.ID) + } else { + 
fmt.Printf("❌ Command '%s' failed with exit code %d\n", cmd.ID, result.ExitCode) + } } } - // Prepare command results as user message - resultsJSON, err := json.MarshalIndent(commandResults, "", " ") + // Execute eBPF programs if present + var ebpfResults []map[string]interface{} + if len(diagnosticResp.EBPFPrograms) > 0 { + fmt.Printf("đŸ”Ŧ Executing %d eBPF programs...\n", len(diagnosticResp.EBPFPrograms)) + ebpfResults = a.executeEBPFPrograms(diagnosticResp.EBPFPrograms) + } + + // Prepare combined results as user message + allResults := map[string]interface{}{ + "command_results": commandResults, + "executed_commands": len(commandResults), + } + + // Include eBPF results if any were executed + if len(ebpfResults) > 0 { + allResults["ebpf_results"] = ebpfResults + allResults["executed_ebpf_programs"] = len(ebpfResults) + + // Extract evidence summary for TensorZero + evidenceSummary := make([]string, 0) + for _, result := range ebpfResults { + name := result["name"] + eventCount := result["data_points"] + description := result["description"] + status := result["status"] + + summaryStr := fmt.Sprintf("%s: %v events (%s) - %s", name, eventCount, status, description) + evidenceSummary = append(evidenceSummary, summaryStr) + } + allResults["ebpf_evidence_summary"] = evidenceSummary + + fmt.Printf("īŋŊ Sending eBPF monitoring data to TensorZero:\n") + for _, summary := range evidenceSummary { + fmt.Printf(" - %s\n", summary) + } + + fmt.Printf("✅ Executed %d commands, %d eBPF programs\n", len(commandResults), len(ebpfResults)) + } else { + fmt.Printf("✅ Executed %d commands\n", len(commandResults)) + } + + resultsJSON, err := json.MarshalIndent(allResults, "", " ") if err != nil { return fmt.Errorf("failed to marshal command results: %w", err) } @@ -178,6 +217,127 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error { return nil } +// executeEBPFPrograms executes REAL eBPF monitoring programs using the actual eBPF manager +func (a 
*LinuxDiagnosticAgent) executeEBPFPrograms(ebpfPrograms []EBPFRequest) []map[string]interface{} { + var results []map[string]interface{} + + if a.ebpfManager == nil { + fmt.Printf("❌ eBPF manager not initialized\n") + return results + } + + for _, prog := range ebpfPrograms { + fmt.Printf("đŸ”Ŧ Starting eBPF program [%s]: %s -> %s (%ds)\n", prog.Name, prog.Type, prog.Target, int(prog.Duration)) + + // Actually start the eBPF program using the real manager + programID, err := a.ebpfManager.StartEBPFProgram(prog) + if err != nil { + fmt.Printf("❌ Failed to start eBPF program [%s]: %v\n", prog.Name, err) + result := map[string]interface{}{ + "name": prog.Name, + "type": prog.Type, + "target": prog.Target, + "duration": int(prog.Duration), + "description": prog.Description, + "status": "failed", + "error": err.Error(), + "success": false, + } + results = append(results, result) + continue + } + + // Let the eBPF program run for the specified duration + fmt.Printf("⏰ Waiting %d seconds for eBPF program to collect data...\n", int(prog.Duration)) + time.Sleep(time.Duration(prog.Duration) * time.Second) + + // Give the collectEvents goroutine a moment to finish and store results + fmt.Printf("âŗ Allowing program to complete data collection...\n") + time.Sleep(500 * time.Millisecond) + + // Get the results (should be in completedResults now) + fmt.Printf("📊 Getting results for eBPF program [%s]...\n", prog.Name) + + // Use a channel to implement timeout for GetProgramResults + type resultPair struct { + trace *EBPFTrace + err error + } + resultChan := make(chan resultPair, 1) + + go func() { + trace, err := a.ebpfManager.GetProgramResults(programID) + resultChan <- resultPair{trace, err} + }() + + var trace *EBPFTrace + var resultErr error + + select { + case result := <-resultChan: + trace = result.trace + resultErr = result.err + case <-time.After(3 * time.Second): + resultErr = fmt.Errorf("timeout getting results after 3 seconds") + } + + // Try to stop the program (may 
already be stopped by collectEvents) + fmt.Printf("🛑 Stopping eBPF program [%s]...\n", prog.Name) + stopErr := a.ebpfManager.StopProgram(programID) + if stopErr != nil { + fmt.Printf("âš ī¸ eBPF program [%s] cleanup: %v (may have already completed)\n", prog.Name, stopErr) + // Don't return here, we still want to process results if we got them + } + + if resultErr != nil { + fmt.Printf("❌ Failed to get results for eBPF program [%s]: %v\n", prog.Name, resultErr) + result := map[string]interface{}{ + "name": prog.Name, + "type": prog.Type, + "target": prog.Target, + "duration": int(prog.Duration), + "description": prog.Description, + "status": "collection_failed", + "error": resultErr.Error(), + "success": false, + } + results = append(results, result) + continue + } // Process the real eBPF trace data + result := map[string]interface{}{ + "name": prog.Name, + "type": prog.Type, + "target": prog.Target, + "duration": int(prog.Duration), + "description": prog.Description, + "status": "completed", + "success": true, + } + + // Extract real data from the trace + if trace != nil { + result["trace_id"] = trace.TraceID + result["data_points"] = trace.EventCount + result["events"] = trace.Events + result["summary"] = trace.Summary + result["process_list"] = trace.ProcessList + result["start_time"] = trace.StartTime.Format(time.RFC3339) + result["end_time"] = trace.EndTime.Format(time.RFC3339) + result["actual_duration"] = trace.EndTime.Sub(trace.StartTime).Seconds() + + fmt.Printf("✅ eBPF program [%s] completed - collected %d real events\n", prog.Name, trace.EventCount) + } else { + result["data_points"] = 0 + result["error"] = "No trace data returned" + fmt.Printf("âš ī¸ eBPF program [%s] completed but returned no trace data\n", prog.Name) + } + + results = append(results, result) + } + + return results +} + // TensorZeroRequest represents a request structure compatible with TensorZero's episode_id type TensorZeroRequest struct { Model string `json:"model"` @@ -193,6 
+353,11 @@ type TensorZeroResponse struct { // sendRequest sends a request to the TensorZero API via Supabase proxy with JWT authentication func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessage) (*openai.ChatCompletionResponse, error) { + return a.sendRequestWithEpisode(messages, "") +} + +// sendRequestWithEpisode sends a request with a specific episode ID +func (a *LinuxDiagnosticAgent) sendRequestWithEpisode(messages []openai.ChatCompletionMessage, episodeID string) (*openai.ChatCompletionResponse, error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -202,9 +367,12 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa Messages: messages, } - // Include tensorzero::episode_id for conversation continuity (if we have one) + // Include tensorzero::episode_id for conversation continuity + // Use agent's existing episode ID if available, otherwise use provided one if a.episodeID != "" { tzRequest.EpisodeID = a.episodeID + } else if episodeID != "" { + tzRequest.EpisodeID = episodeID } fmt.Printf("Debug: Sending request to model: %s", a.model) diff --git a/go.mod b/go.mod index 11af360..22f8b94 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( require ( github.com/go-ole/go-ole v1.2.6 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect diff --git a/go.sum b/go.sum index 451412b..24815e2 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,8 @@ github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lG github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod 
h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 8c9af45..5a2bbff 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -2,6 +2,7 @@ package auth import ( "bytes" + "encoding/base64" "encoding/json" "fmt" "io" @@ -395,6 +396,104 @@ func (am *AuthManager) loadRefreshTokenFromBackup() (string, error) { return refreshToken, nil } +// GetCurrentAgentID retrieves the agent ID from cache or JWT token +func (am *AuthManager) GetCurrentAgentID() (string, error) { + // First try to read from local cache + agentID, err := am.loadCachedAgentID() + if err == nil && agentID != "" { + return agentID, nil + } + + // Cache miss - extract from JWT token and cache it + token, err := am.LoadToken() + if err != nil { + return "", fmt.Errorf("failed to load token: %w", err) + } + + // Extract agent ID from JWT 'sub' field + agentID, err = am.extractAgentIDFromJWT(token.AccessToken) + if err != nil { + return "", fmt.Errorf("failed to extract agent ID from JWT: %w", err) + } + + // Cache the agent ID for future use + if err := am.cacheAgentID(agentID); err != nil { + // Log warning but don't fail - we still have the agent ID + fmt.Printf("Warning: Failed to cache agent ID: %v\n", err) + } + + return agentID, nil +} + +// extractAgentIDFromJWT decodes the JWT token and extracts the agent ID from 'sub' field +func (am *AuthManager) extractAgentIDFromJWT(tokenString string) (string, error) { + // Basic JWT decoding without verification (since we trust Supabase) + parts := 
strings.Split(tokenString, ".") + if len(parts) != 3 { + return "", fmt.Errorf("invalid JWT token format") + } + + // Decode the payload (second part) + payload := parts[1] + + // Add padding if needed for base64 decoding + for len(payload)%4 != 0 { + payload += "=" + } + + decoded, err := base64.URLEncoding.DecodeString(payload) + if err != nil { + return "", fmt.Errorf("failed to decode JWT payload: %w", err) + } + + // Parse JSON payload + var claims map[string]interface{} + if err := json.Unmarshal(decoded, &claims); err != nil { + return "", fmt.Errorf("failed to parse JWT claims: %w", err) + } + + // The agent ID is in the 'sub' field (subject) + if agentID, ok := claims["sub"].(string); ok && agentID != "" { + return agentID, nil + } + + return "", fmt.Errorf("agent ID (sub) not found in JWT claims") +} + +// loadCachedAgentID reads the cached agent ID from local storage +func (am *AuthManager) loadCachedAgentID() (string, error) { + agentIDPath := filepath.Join(TokenStorageDir, "agent_id") + + data, err := os.ReadFile(agentIDPath) + if err != nil { + return "", fmt.Errorf("failed to read cached agent ID: %w", err) + } + + agentID := strings.TrimSpace(string(data)) + if agentID == "" { + return "", fmt.Errorf("cached agent ID is empty") + } + + return agentID, nil +} + +// cacheAgentID stores the agent ID in local cache +func (am *AuthManager) cacheAgentID(agentID string) error { + // Ensure the directory exists + if err := am.EnsureTokenStorageDir(); err != nil { + return fmt.Errorf("failed to ensure storage directory: %w", err) + } + + agentIDPath := filepath.Join(TokenStorageDir, "agent_id") + + // Write agent ID to file with secure permissions + if err := os.WriteFile(agentIDPath, []byte(agentID), 0600); err != nil { + return fmt.Errorf("failed to write agent ID cache: %w", err) + } + + return nil +} + func (am *AuthManager) getTokenPath() string { if am.config.TokenPath != "" { return am.config.TokenPath diff --git a/internal/metrics/collector.go 
b/internal/metrics/collector.go index fa01e40..1b587fd 100644 --- a/internal/metrics/collector.go +++ b/internal/metrics/collector.go @@ -242,6 +242,9 @@ func (c *Collector) CreateMetricsRequest(agentID string, systemMetrics *types.Sy "load15": systemMetrics.LoadAvg15, }, OSInfo: map[string]string{ + "cpu_cores": fmt.Sprintf("%d", systemMetrics.CPUCores), + "memory": fmt.Sprintf("%.1fGi", float64(systemMetrics.MemoryTotal)/(1024*1024*1024)), + "uptime": "unknown", // Will be calculated by the server or client "platform": systemMetrics.Platform, "platform_family": systemMetrics.PlatformFamily, "platform_version": systemMetrics.PlatformVersion, diff --git a/investigation_server.go b/investigation_server.go new file mode 100644 index 0000000..dee5701 --- /dev/null +++ b/investigation_server.go @@ -0,0 +1,527 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + "strings" + "time" + + "nannyagentv2/internal/auth" + "nannyagentv2/internal/metrics" + + "github.com/sashabaranov/go-openai" +) + +// InvestigationRequest represents a request from Supabase to start an investigation +type InvestigationRequest struct { + InvestigationID string `json:"investigation_id"` + ApplicationGroup string `json:"application_group"` + Issue string `json:"issue"` + Context map[string]string `json:"context"` + Priority string `json:"priority"` + InitiatedBy string `json:"initiated_by"` +} + +// InvestigationResponse represents the agent's response to an investigation +type InvestigationResponse struct { + AgentID string `json:"agent_id"` + InvestigationID string `json:"investigation_id"` + Status string `json:"status"` + Commands []CommandResult `json:"commands,omitempty"` + AIResponse string `json:"ai_response,omitempty"` + EpisodeID string `json:"episode_id,omitempty"` + Timestamp time.Time `json:"timestamp"` + Error string `json:"error,omitempty"` +} + +// InvestigationServer handles reverse investigation requests from Supabase +type InvestigationServer struct { + 
agent *LinuxDiagnosticAgent // Original agent for direct user interactions + applicationAgent *LinuxDiagnosticAgent // Separate agent for application-initiated investigations + port string + agentID string + metricsCollector *metrics.Collector + authManager *auth.AuthManager + startTime time.Time + supabaseURL string +} + +// NewInvestigationServer creates a new investigation server +func NewInvestigationServer(agent *LinuxDiagnosticAgent, authManager *auth.AuthManager) *InvestigationServer { + port := os.Getenv("AGENT_PORT") + if port == "" { + port = "1234" + } + + // Get agent ID from authentication system + var agentID string + if authManager != nil { + if id, err := authManager.GetCurrentAgentID(); err == nil { + agentID = id + fmt.Printf("✅ Retrieved agent ID from auth manager: %s\n", agentID) + } else { + fmt.Printf("❌ Failed to get agent ID from auth manager: %v\n", err) + } + } + + // Fallback to environment variable or generate one if auth fails + if agentID == "" { + agentID = os.Getenv("AGENT_ID") + if agentID == "" { + agentID = fmt.Sprintf("agent-%d", time.Now().Unix()) + } + } + + // Create metrics collector + metricsCollector := metrics.NewCollector("v2.0.0") + + // Create a separate agent for application-initiated investigations + applicationAgent := NewLinuxDiagnosticAgent() + // Override the model to use the application-specific function + applicationAgent.model = "tensorzero::function_name::diagnose_and_heal_application" + + return &InvestigationServer{ + agent: agent, + applicationAgent: applicationAgent, + port: port, + agentID: agentID, + metricsCollector: metricsCollector, + authManager: authManager, + startTime: time.Now(), + supabaseURL: os.Getenv("SUPABASE_PROJECT_URL"), + } +} + +// DiagnoseIssueForApplication handles diagnostic requests initiated from application/portal +func (s *InvestigationServer) DiagnoseIssueForApplication(issue, episodeID string) error { + // Set the episode ID on the application agent for continuity + 
s.applicationAgent.episodeID = episodeID + return s.applicationAgent.DiagnoseIssue(issue) +} + +// Start starts the HTTP server and realtime polling for investigation requests +func (s *InvestigationServer) Start() error { + mux := http.NewServeMux() + + // Health check endpoint + mux.HandleFunc("/health", s.handleHealth) + + // Investigation endpoint + mux.HandleFunc("/investigate", s.handleInvestigation) + + // Agent status endpoint + mux.HandleFunc("/status", s.handleStatus) + + // Start realtime polling for backend-initiated investigations + if s.supabaseURL != "" && s.authManager != nil { + go s.startRealtimePolling() + fmt.Printf("🔄 Realtime investigation polling enabled\n") + } else { + fmt.Printf("âš ī¸ Realtime investigation polling disabled (missing Supabase config or auth)\n") + } + + server := &http.Server{ + Addr: ":" + s.port, + Handler: mux, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + } + + fmt.Printf("🔍 Investigation server started on port %s (Agent ID: %s)\n", s.port, s.agentID) + return server.ListenAndServe() +} + +// handleHealth responds to health check requests +func (s *InvestigationServer) handleHealth(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + response := map[string]interface{}{ + "status": "healthy", + "agent_id": s.agentID, + "timestamp": time.Now(), + "version": "v2.0.0", + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +// handleStatus responds with agent status and capabilities +func (s *InvestigationServer) handleStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Collect current system metrics + systemMetrics, err := s.metricsCollector.GatherSystemMetrics() + if err != nil { + http.Error(w, fmt.Sprintf("Failed to collect 
metrics: %v", err), http.StatusInternalServerError) + return + } + + // Convert to metrics request format for consistent data structure + metricsReq := s.metricsCollector.CreateMetricsRequest(s.agentID, systemMetrics) + + response := map[string]interface{}{ + "agent_id": s.agentID, + "status": "ready", + "capabilities": []string{"system_diagnostics", "ebpf_monitoring", "command_execution", "ai_analysis"}, + "system_info": map[string]interface{}{ + "os": fmt.Sprintf("%s %s", metricsReq.OSInfo["platform"], metricsReq.OSInfo["platform_version"]), + "kernel": metricsReq.KernelVersion, + "architecture": metricsReq.OSInfo["kernel_arch"], + "cpu_cores": metricsReq.OSInfo["cpu_cores"], + "memory": metricsReq.MemoryUsage, + "private_ips": metricsReq.IPAddress, + "load_average": fmt.Sprintf("%.2f, %.2f, %.2f", + metricsReq.LoadAverages["load1"], + metricsReq.LoadAverages["load5"], + metricsReq.LoadAverages["load15"]), + "disk_usage": fmt.Sprintf("Root: %.0fG/%.0fG (%.0f%% used)", + float64(metricsReq.FilesystemInfo[0].Used)/1024/1024/1024, + float64(metricsReq.FilesystemInfo[0].Total)/1024/1024/1024, + metricsReq.DiskUsage), + }, + "uptime": time.Since(s.startTime), + "last_contact": time.Now(), + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +// sendCommandResultsToTensorZero sends command results back to TensorZero and continues conversation +func (s *InvestigationServer) sendCommandResultsToTensorZero(diagnosticResp DiagnosticResponse, commandResults []CommandResult) (interface{}, error) { + // Build conversation history like in agent.go + messages := []openai.ChatCompletionMessage{ + // Add the original diagnostic response as assistant message + { + Role: openai.ChatMessageRoleAssistant, + Content: fmt.Sprintf(`{"response_type":"diagnostic","reasoning":"%s","commands":%s}`, + diagnosticResp.Reasoning, + mustMarshalJSON(diagnosticResp.Commands)), + }, + } + + // Add command results as user message (same as agent.go 
does) + resultsJSON, err := json.MarshalIndent(commandResults, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal command results: %w", err) + } + + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: string(resultsJSON), + }) + + // Send to TensorZero via application agent's sendRequest method + fmt.Printf("🔄 Sending command results to TensorZero for analysis...\n") + response, err := s.applicationAgent.sendRequest(messages) + if err != nil { + return nil, fmt.Errorf("failed to send request to TensorZero: %w", err) + } + + if len(response.Choices) == 0 { + return nil, fmt.Errorf("no choices in TensorZero response") + } + + content := response.Choices[0].Message.Content + fmt.Printf("🤖 TensorZero continued analysis:\n%s\n", content) + + // Try to parse the response to determine if it's diagnostic or resolution + var diagnosticNextResp DiagnosticResponse + var resolutionResp ResolutionResponse + + // Check if it's another diagnostic response + if err := json.Unmarshal([]byte(content), &diagnosticNextResp); err == nil && diagnosticNextResp.ResponseType == "diagnostic" { + fmt.Printf("🔄 TensorZero requests %d more commands\n", len(diagnosticNextResp.Commands)) + return map[string]interface{}{ + "type": "diagnostic", + "response": diagnosticNextResp, + "raw": content, + }, nil + } + + // Check if it's a resolution response + if err := json.Unmarshal([]byte(content), &resolutionResp); err == nil && resolutionResp.ResponseType == "resolution" { + fmt.Printf("✅ TensorZero provided final resolution\n") + return map[string]interface{}{ + "type": "resolution", + "response": resolutionResp, + "raw": content, + }, nil + } + + // Return raw response if we can't parse it + return map[string]interface{}{ + "type": "unknown", + "raw": content, + }, nil +} + +// Helper function to marshal JSON without errors +func mustMarshalJSON(v interface{}) string { + data, _ := json.Marshal(v) + return string(data) +} + 
+// processInvestigation handles the actual investigation using TensorZero +// This endpoint receives either: +// 1. DiagnosticResponse - Commands and eBPF programs to execute +// 2. ResolutionResponse - Final resolution (no execution needed) +func (s *InvestigationServer) handleInvestigation(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed - only POST accepted", http.StatusMethodNotAllowed) + return + } + + // Parse the request body to determine what type of response this is + var requestBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + // Check the response_type field to determine how to handle this + responseType, ok := requestBody["response_type"].(string) + if !ok { + http.Error(w, "Missing or invalid response_type field", http.StatusBadRequest) + return + } + + fmt.Printf("📋 Received investigation payload with response_type: %s\n", responseType) + + switch responseType { + case "diagnostic": + // This is a DiagnosticResponse with commands to execute + response := s.handleDiagnosticExecution(requestBody) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + + case "resolution": + // This is a ResolutionResponse - final result, just acknowledge + fmt.Printf("📋 Received final resolution from backend\n") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "message": "Resolution received and acknowledged", + "agent_id": s.agentID, + }) + + default: + http.Error(w, fmt.Sprintf("Unknown response_type: %s", responseType), http.StatusBadRequest) + return + } +} + +// handleDiagnosticExecution executes commands from a DiagnosticResponse +func (s *InvestigationServer) handleDiagnosticExecution(requestBody map[string]interface{}) map[string]interface{} { + 
// Parse as DiagnosticResponse + var diagnosticResp DiagnosticResponse + + // Convert the map back to JSON and then parse it properly + jsonData, err := json.Marshal(requestBody) + if err != nil { + return map[string]interface{}{ + "success": false, + "error": fmt.Sprintf("Failed to re-marshal request: %v", err), + "agent_id": s.agentID, + } + } + + if err := json.Unmarshal(jsonData, &diagnosticResp); err != nil { + return map[string]interface{}{ + "success": false, + "error": fmt.Sprintf("Failed to parse DiagnosticResponse: %v", err), + "agent_id": s.agentID, + } + } + + fmt.Printf("📋 Executing %d commands from backend\n", len(diagnosticResp.Commands)) + + // Execute all commands + commandResults := make([]CommandResult, 0, len(diagnosticResp.Commands)) + + for _, cmd := range diagnosticResp.Commands { + fmt.Printf("âš™ī¸ Executing command '%s': %s\n", cmd.ID, cmd.Command) + + // Use the agent's executor to run the command + result := s.agent.executor.Execute(cmd) + commandResults = append(commandResults, result) + + fmt.Printf("✅ Command '%s' completed with exit code %d\n", cmd.ID, result.ExitCode) + if result.Error != "" { + fmt.Printf("âš ī¸ Command '%s' had error: %s\n", cmd.ID, result.Error) + } + } + + // Send command results back to TensorZero for continued analysis + fmt.Printf("🔄 Sending %d command results back to TensorZero for continued analysis\n", len(commandResults)) + + nextResponse, err := s.sendCommandResultsToTensorZero(diagnosticResp, commandResults) + if err != nil { + return map[string]interface{}{ + "success": false, + "error": fmt.Sprintf("Failed to continue TensorZero conversation: %v", err), + "agent_id": s.agentID, + "command_results": commandResults, // Still return the results + } + } + + // Return both the command results and the next response from TensorZero + return map[string]interface{}{ + "success": true, + "agent_id": s.agentID, + "command_results": commandResults, + "commands_executed": len(commandResults), + "next_response": 
nextResponse, + "timestamp": time.Now().Format(time.RFC3339), + } +} + +// PendingInvestigation represents a pending investigation from the database +type PendingInvestigation struct { + ID string `json:"id"` + InvestigationID string `json:"investigation_id"` + AgentID string `json:"agent_id"` + DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"` + EpisodeID *string `json:"episode_id"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` +} + +// startRealtimePolling begins polling for pending investigations +func (s *InvestigationServer) startRealtimePolling() { + fmt.Printf("🔄 Starting realtime investigation polling for agent %s\n", s.agentID) + + ticker := time.NewTicker(5 * time.Second) // Poll every 5 seconds + defer ticker.Stop() + + for range ticker.C { + s.checkForPendingInvestigations() + } +} + +// checkForPendingInvestigations checks for new pending investigations +func (s *InvestigationServer) checkForPendingInvestigations() { + url := fmt.Sprintf("%s/rest/v1/pending_investigations?agent_id=eq.%s&status=eq.pending&order=created_at.desc", + s.supabaseURL, s.agentID) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return // Silent fail for polling + } + + // Get token from auth manager + authToken, err := s.authManager.LoadToken() + if err != nil { + return // Silent fail for polling + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken.AccessToken)) + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return // Silent fail for polling + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return // Silent fail for polling + } + + var investigations []PendingInvestigation + err = json.NewDecoder(resp.Body).Decode(&investigations) + if err != nil { + return // Silent fail for polling + } + + for _, investigation := range investigations { + fmt.Printf("🔍 Found pending 
investigation: %s\n", investigation.ID) + go s.handlePendingInvestigation(investigation) + } +} + +// handlePendingInvestigation processes a single pending investigation +func (s *InvestigationServer) handlePendingInvestigation(investigation PendingInvestigation) { + fmt.Printf("🚀 Processing realtime investigation %s\n", investigation.InvestigationID) + + // Mark as executing + err := s.updateInvestigationStatus(investigation.ID, "executing", nil, nil) + if err != nil { + fmt.Printf("❌ Failed to mark investigation as executing: %v\n", err) + return + } + + // Execute diagnostic commands using existing handleDiagnosticExecution method + results := s.handleDiagnosticExecution(investigation.DiagnosticPayload) + + // Mark as completed with results + err = s.updateInvestigationStatus(investigation.ID, "completed", results, nil) + if err != nil { + fmt.Printf("❌ Failed to mark investigation as completed: %v\n", err) + return + } + + fmt.Printf("✅ Realtime investigation %s completed successfully\n", investigation.InvestigationID) +} + +// updateInvestigationStatus updates the status of a pending investigation +func (s *InvestigationServer) updateInvestigationStatus(id, status string, results map[string]interface{}, errorMsg *string) error { + updateData := map[string]interface{}{ + "status": status, + } + + if status == "executing" { + updateData["started_at"] = time.Now().UTC().Format(time.RFC3339) + } else if status == "completed" { + updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339) + if results != nil { + updateData["command_results"] = results + } + } else if status == "failed" && errorMsg != nil { + updateData["error_message"] = *errorMsg + updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339) + } + + jsonData, err := json.Marshal(updateData) + if err != nil { + return fmt.Errorf("failed to marshal update data: %v", err) + } + + url := fmt.Sprintf("%s/rest/v1/pending_investigations?id=eq.%s", s.supabaseURL, id) + req, err := 
http.NewRequest("PATCH", url, strings.NewReader(string(jsonData))) + if err != nil { + return fmt.Errorf("failed to create request: %v", err) + } + + // Get token from auth manager + authToken, err := s.authManager.LoadToken() + if err != nil { + return fmt.Errorf("failed to load auth token: %v", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken.AccessToken)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to update investigation: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 && resp.StatusCode != 204 { + return fmt.Errorf("supabase update error: %d", resp.StatusCode) + } + + return nil +} diff --git a/main.go b/main.go index 47e8b64..8186c2f 100644 --- a/main.go +++ b/main.go @@ -168,9 +168,21 @@ func main() { fmt.Println("✅ Authentication successful!") - // Initialize the diagnostic agent + // Initialize the diagnostic agent for interactive CLI use agent := NewLinuxDiagnosticAgent() + // Initialize a separate agent for WebSocket investigations using the application model + applicationAgent := NewLinuxDiagnosticAgent() + applicationAgent.model = "tensorzero::function_name::diagnose_and_heal_application" + + // Start WebSocket client for backend communications and investigations + wsClient := NewWebSocketClient(applicationAgent, authManager) + go func() { + if err := wsClient.Start(); err != nil { + log.Printf("❌ WebSocket client error: %v", err) + } + }() + // Start background metrics collection in a goroutine go func() { fmt.Println("â¤ī¸ Starting background metrics collection and heartbeat...") diff --git a/websocket_client.go b/websocket_client.go new file mode 100644 index 0000000..e3fd9fa --- /dev/null +++ b/websocket_client.go @@ -0,0 +1,863 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net" + "net/http" + "os" + "os/exec" + "strings" + 
// Wire-level message types exchanged with the backend over the WebSocket.
type (
	// WebSocketMessage is the envelope for every frame: Type selects the
	// handler and Data carries the type-specific payload.
	WebSocketMessage struct {
		Type string      `json:"type"`
		Data interface{} `json:"data"`
	}

	// InvestigationTask is the payload of an "investigation_task" message.
	// DiagnosticPayload is kept as a generic map; EpisodeID is optional and
	// omitted from JSON when empty.
	InvestigationTask struct {
		TaskID            string                 `json:"task_id"`
		InvestigationID   string                 `json:"investigation_id"`
		AgentID           string                 `json:"agent_id"`
		DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
		EpisodeID         string                 `json:"episode_id,omitempty"`
	}

	// TaskResult is the payload of a "task_result" message. CommandResults
	// and Error are omitted from JSON when empty.
	TaskResult struct {
		TaskID         string                 `json:"task_id"`
		Success        bool                   `json:"success"`
		CommandResults map[string]interface{} `json:"command_results,omitempty"`
		Error          string                 `json:"error,omitempty"`
	}

	// HeartbeatData is the payload of a "heartbeat" message.
	HeartbeatData struct {
		AgentID   string    `json:"agent_id"`
		Timestamp time.Time `json:"timestamp"`
		Version   string    `json:"version"`
	}
)
err) + } + } + + // Fallback to environment variable or generate one if auth fails + if agentID == "" { + agentID = os.Getenv("AGENT_ID") + if agentID == "" { + agentID = fmt.Sprintf("agent-%d", time.Now().Unix()) + } + } + + supabaseURL := os.Getenv("SUPABASE_PROJECT_URL") + if supabaseURL == "" { + log.Fatal("❌ SUPABASE_PROJECT_URL environment variable is required") + } + + // Create metrics collector + metricsCollector := metrics.NewCollector("v2.0.0") + + ctx, cancel := context.WithCancel(context.Background()) + + return &WebSocketClient{ + agent: agent, + agentID: agentID, + authManager: authManager, + metricsCollector: metricsCollector, + supabaseURL: supabaseURL, + ctx: ctx, + cancel: cancel, + } +} + +// Start starts the WebSocket connection and message handling +func (w *WebSocketClient) Start() error { + // Starting WebSocket client + + if err := w.connect(); err != nil { + return fmt.Errorf("failed to establish WebSocket connection: %v", err) + } + + // Start message reading loop + go w.handleMessages() + + // Start heartbeat + go w.startHeartbeat() + + // Start database polling for pending investigations + go w.pollPendingInvestigations() + + // WebSocket client started + return nil +} + +// Stop closes the WebSocket connection +func (c *WebSocketClient) Stop() { + fmt.Println("🛑 Stopping WebSocket client...") + c.cancel() + if c.conn != nil { + c.conn.Close() + } +} + +// getAuthToken retrieves authentication token +func (c *WebSocketClient) getAuthToken() error { + if c.authManager == nil { + return fmt.Errorf("auth manager not available") + } + + token, err := c.authManager.EnsureAuthenticated() + if err != nil { + return fmt.Errorf("authentication failed: %v", err) + } + + c.token = token.AccessToken + return nil +} + +// connect establishes WebSocket connection +func (c *WebSocketClient) connect() error { + // Get fresh auth token + if err := c.getAuthToken(); err != nil { + return fmt.Errorf("failed to get auth token: %v", err) + } + + // Convert 
HTTP URL to WebSocket URL + wsURL := strings.Replace(c.supabaseURL, "https://", "wss://", 1) + wsURL = strings.Replace(wsURL, "http://", "ws://", 1) + wsURL += "/functions/v1/websocket-agent-handler" + + // Connecting to WebSocket + + // Set up headers + headers := http.Header{} + headers.Set("Authorization", "Bearer "+c.token) + + // Connect + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + } + + conn, resp, err := dialer.Dial(wsURL, headers) + if err != nil { + c.consecutiveFailures++ + if c.consecutiveFailures >= 5 && resp != nil { + fmt.Printf("❌ WebSocket handshake failed with status: %d (failure #%d)\n", resp.StatusCode, c.consecutiveFailures) + } + return fmt.Errorf("websocket connection failed: %v", err) + } + + c.conn = conn + // WebSocket client connected + return nil +} + +// handleMessages processes incoming WebSocket messages +func (c *WebSocketClient) handleMessages() { + defer func() { + if c.conn != nil { + // Closing WebSocket connection + c.conn.Close() + } + }() + + // Started WebSocket message listener + connectionStart := time.Now() + + for { + select { + case <-c.ctx.Done(): + // Only log context cancellation if there have been failures + if c.consecutiveFailures >= 5 { + fmt.Printf("📡 Context cancelled after %v, stopping message handler\n", time.Since(connectionStart)) + } + return + default: + // Set read deadline to detect connection issues + c.conn.SetReadDeadline(time.Now().Add(90 * time.Second)) + + var message WebSocketMessage + readStart := time.Now() + err := c.conn.ReadJSON(&message) + readDuration := time.Since(readStart) + + if err != nil { + connectionDuration := time.Since(connectionStart) + + // Only log specific errors after failure threshold + if c.consecutiveFailures >= 5 { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + log.Printf("🔒 WebSocket closed normally after %v: %v", connectionDuration, err) + } else if websocket.IsUnexpectedCloseError(err, 
websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("đŸ’Ĩ ABNORMAL CLOSE after %v (code 1006 = server-side timeout/kill): %v", connectionDuration, err) + log.Printf("🕒 Last read took %v, connection lived %v", readDuration, connectionDuration) + } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + log.Printf("⏰ READ TIMEOUT after %v: %v", connectionDuration, err) + } else { + log.Printf("❌ WebSocket error after %v: %v", connectionDuration, err) + } + } + + // Track consecutive failures for diagnostic threshold + c.consecutiveFailures++ + + // Only show diagnostics after multiple failures + if c.consecutiveFailures >= 5 { + log.Printf("🔍 DIAGNOSTIC - Connection failed #%d after %v", c.consecutiveFailures, connectionDuration) + } + + // Attempt reconnection instead of returning immediately + go c.attemptReconnection() + return + } + + // Received WebSocket message successfully - reset failure counter + c.consecutiveFailures = 0 + + switch message.Type { + case "connection_ack": + // Connection acknowledged + + case "heartbeat_ack": + // Heartbeat acknowledged + + case "investigation_task": + // Received investigation task - processing + go c.handleInvestigationTask(message.Data) + + case "task_result_ack": + // Task result acknowledged + + default: + log.Printf("âš ī¸ Unknown message type: %s", message.Type) + } + } + } +} + +// handleInvestigationTask processes investigation tasks from the backend +func (c *WebSocketClient) handleInvestigationTask(data interface{}) { + // Parse task data + taskBytes, err := json.Marshal(data) + if err != nil { + log.Printf("❌ Error marshaling task data: %v", err) + return + } + + var task InvestigationTask + err = json.Unmarshal(taskBytes, &task) + if err != nil { + log.Printf("❌ Error unmarshaling investigation task: %v", err) + return + } + + // Processing investigation task + + // Execute diagnostic commands + results, err := c.executeDiagnosticCommands(task.DiagnosticPayload) + + // Prepare 
task result + taskResult := TaskResult{ + TaskID: task.TaskID, + Success: err == nil, + } + + if err != nil { + taskResult.Error = err.Error() + fmt.Printf("❌ Task execution failed: %v\n", err) + } else { + taskResult.CommandResults = results + // Task executed successfully + } + + // Send result back + c.sendTaskResult(taskResult) +} + +// executeDiagnosticCommands executes the commands from a diagnostic response +func (c *WebSocketClient) executeDiagnosticCommands(diagnosticPayload map[string]interface{}) (map[string]interface{}, error) { + fmt.Println("🔧 Executing diagnostic commands...") + + results := map[string]interface{}{ + "agent_id": c.agentID, + "execution_time": time.Now().UTC().Format(time.RFC3339), + "command_results": []map[string]interface{}{}, + } + + // Extract commands from diagnostic payload + commands, ok := diagnosticPayload["commands"].([]interface{}) + if !ok { + return nil, fmt.Errorf("no commands found in diagnostic payload") + } + + var commandResults []map[string]interface{} + + for _, cmd := range commands { + cmdMap, ok := cmd.(map[string]interface{}) + if !ok { + continue + } + + id, _ := cmdMap["id"].(string) + command, _ := cmdMap["command"].(string) + description, _ := cmdMap["description"].(string) + + if command == "" { + continue + } + + // Executing command + + // Execute the command + output, exitCode, err := c.executeCommand(command) + + result := map[string]interface{}{ + "id": id, + "command": command, + "description": description, + "output": output, + "exit_code": exitCode, + "success": err == nil && exitCode == 0, + } + + if err != nil { + result["error"] = err.Error() + fmt.Printf("❌ Command [%s] failed: %v (exit code: %d)\n", id, err, exitCode) + } else { + // Command completed successfully - output captured + } + + commandResults = append(commandResults, result) + } + + results["command_results"] = commandResults + results["total_commands"] = len(commandResults) + results["successful_commands"] = 
c.countSuccessfulCommands(commandResults) + + // Execute eBPF programs if present + ebpfPrograms, hasEBPF := diagnosticPayload["ebpf_programs"].([]interface{}) + if hasEBPF && len(ebpfPrograms) > 0 { + fmt.Printf("đŸ”Ŧ Executing %d eBPF programs...\n", len(ebpfPrograms)) + ebpfResults := c.executeEBPFPrograms(ebpfPrograms) + results["ebpf_results"] = ebpfResults + results["total_ebpf_programs"] = len(ebpfPrograms) + } else { + fmt.Printf("â„šī¸ No eBPF programs in diagnostic payload\n") + } + + fmt.Printf("✅ Executed %d commands, %d successful\n", + results["total_commands"], results["successful_commands"]) + + return results, nil +} + +// executeEBPFPrograms executes eBPF monitoring programs using the real eBPF manager +func (c *WebSocketClient) executeEBPFPrograms(ebpfPrograms []interface{}) []map[string]interface{} { + var ebpfRequests []EBPFRequest + + // Convert interface{} to EBPFRequest structs + for _, prog := range ebpfPrograms { + progMap, ok := prog.(map[string]interface{}) + if !ok { + continue + } + + name, _ := progMap["name"].(string) + progType, _ := progMap["type"].(string) + target, _ := progMap["target"].(string) + duration, _ := progMap["duration"].(float64) + description, _ := progMap["description"].(string) + + if name == "" || progType == "" || target == "" { + continue + } + + ebpfRequests = append(ebpfRequests, EBPFRequest{ + Name: name, + Type: progType, + Target: target, + Duration: int(duration), + Description: description, + }) + } + + // Execute eBPF programs using the agent's eBPF execution logic + return c.agent.executeEBPFPrograms(ebpfRequests) +} + +// executeCommandsFromPayload executes commands from a payload and returns results +func (c *WebSocketClient) executeCommandsFromPayload(commands []interface{}) []map[string]interface{} { + var commandResults []map[string]interface{} + + for _, cmd := range commands { + cmdMap, ok := cmd.(map[string]interface{}) + if !ok { + continue + } + + id, _ := cmdMap["id"].(string) + command, _ 
:= cmdMap["command"].(string) + description, _ := cmdMap["description"].(string) + + if command == "" { + continue + } + + // Execute the command + output, exitCode, err := c.executeCommand(command) + + result := map[string]interface{}{ + "id": id, + "command": command, + "description": description, + "output": output, + "exit_code": exitCode, + "success": err == nil && exitCode == 0, + } + + if err != nil { + result["error"] = err.Error() + fmt.Printf("❌ Command [%s] failed: %v (exit code: %d)\n", id, err, exitCode) + } else { + fmt.Printf("✅ Command [%s] completed successfully\n", id) + } + + commandResults = append(commandResults, result) + } + + return commandResults +} + +// executeCommand executes a shell command and returns output, exit code, and error +func (c *WebSocketClient) executeCommand(command string) (string, int, error) { + // Parse command into parts + parts := strings.Fields(command) + if len(parts) == 0 { + return "", -1, fmt.Errorf("empty command") + } + + // Create command with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, parts[0], parts[1:]...) 
+ cmd.Env = os.Environ() + + output, err := cmd.CombinedOutput() + exitCode := 0 + + if err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + exitCode = exitError.ExitCode() + } else { + exitCode = -1 + } + } + + return string(output), exitCode, err +} + +// countSuccessfulCommands counts the number of successful commands +func (c *WebSocketClient) countSuccessfulCommands(results []map[string]interface{}) int { + count := 0 + for _, result := range results { + if success, ok := result["success"].(bool); ok && success { + count++ + } + } + return count +} + +// sendTaskResult sends a task result back to the backend +func (c *WebSocketClient) sendTaskResult(result TaskResult) { + message := WebSocketMessage{ + Type: "task_result", + Data: result, + } + + err := c.conn.WriteJSON(message) + if err != nil { + log.Printf("❌ Error sending task result: %v", err) + } +} + +// startHeartbeat sends periodic heartbeat messages +func (c *WebSocketClient) startHeartbeat() { + ticker := time.NewTicker(30 * time.Second) // Heartbeat every 30 seconds + defer ticker.Stop() + + // Starting heartbeat + + for { + select { + case <-c.ctx.Done(): + fmt.Printf("💓 Heartbeat stopped due to context cancellation\n") + return + case <-ticker.C: + // Sending heartbeat + heartbeat := WebSocketMessage{ + Type: "heartbeat", + Data: HeartbeatData{ + AgentID: c.agentID, + Timestamp: time.Now(), + Version: "v2.0.0", + }, + } + + err := c.conn.WriteJSON(heartbeat) + if err != nil { + log.Printf("❌ Error sending heartbeat: %v", err) + fmt.Printf("💓 Heartbeat failed, connection likely dead\n") + return + } + // Heartbeat sent + } + } +} + +// pollPendingInvestigations polls the database for pending investigations +func (c *WebSocketClient) pollPendingInvestigations() { + // Starting database polling + ticker := time.NewTicker(5 * time.Second) // Poll every 5 seconds + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + 
c.checkForPendingInvestigations() + } + } +} + +// checkForPendingInvestigations checks the database for new pending investigations via proxy +func (c *WebSocketClient) checkForPendingInvestigations() { + // Use Edge Function proxy instead of direct database access + url := fmt.Sprintf("%s/functions/v1/agent-database-proxy/pending-investigations", c.supabaseURL) + + // Poll database for pending investigations + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + // Request creation failed + return + } + + // Only JWT token needed for proxy - no API keys exposed + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token)) + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + // Database request failed + return + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return + } + + var investigations []PendingInvestigation + err = json.NewDecoder(resp.Body).Decode(&investigations) + if err != nil { + // Response decode failed + return + } + + for _, investigation := range investigations { + go c.handlePendingInvestigation(investigation) + } +} + +// handlePendingInvestigation processes a pending investigation from database polling +func (c *WebSocketClient) handlePendingInvestigation(investigation PendingInvestigation) { + // Processing pending investigation + + // Mark as executing + err := c.updateInvestigationStatus(investigation.ID, "executing", nil, nil) + if err != nil { + return + } + + // Execute diagnostic commands + results, err := c.executeDiagnosticCommands(investigation.DiagnosticPayload) + + // Prepare the base results map we'll send to DB + resultsForDB := map[string]interface{}{ + "agent_id": c.agentID, + "execution_time": time.Now().UTC().Format(time.RFC3339), + "command_results": results, + } + + // If command execution failed, mark investigation as failed + if err != nil { + errorMsg := err.Error() + // Include partial 
results when possible + if results != nil { + resultsForDB["command_results"] = results + } + c.updateInvestigationStatus(investigation.ID, "failed", resultsForDB, &errorMsg) + // Investigation failed + return + } + + // Try to continue the TensorZero conversation by sending command results back + // Build messages: assistant = diagnostic payload, user = command results + diagJSON, _ := json.Marshal(investigation.DiagnosticPayload) + commandsJSON, _ := json.MarshalIndent(results, "", " ") + + messages := []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleAssistant, + Content: string(diagJSON), + }, + { + Role: openai.ChatMessageRoleUser, + Content: string(commandsJSON), + }, + } + + fmt.Printf("🔄 Sending command results to TensorZero for continued analysis...\n") + fmt.Printf("📤 Command results payload size: %d bytes\n", len(commandsJSON)) + + // Use the episode ID from the investigation to maintain conversation continuity + episodeID := "" + if investigation.EpisodeID != nil { + episodeID = *investigation.EpisodeID + fmt.Printf("🔗 Using episode ID: %s\n", episodeID) + } + + // Continue conversation until resolution (same as agent) + var finalAIContent string + for { + tzResp, tzErr := c.agent.sendRequestWithEpisode(messages, episodeID) + if tzErr != nil { + fmt.Printf("âš ī¸ TensorZero continuation failed: %v\n", tzErr) + // Fall back to marking completed with command results only + c.updateInvestigationStatus(investigation.ID, "completed", resultsForDB, nil) + return + } + + fmt.Printf("✅ TensorZero responded successfully\n") + + if len(tzResp.Choices) == 0 { + fmt.Printf("âš ī¸ No choices in TensorZero response\n") + c.updateInvestigationStatus(investigation.ID, "completed", resultsForDB, nil) + return + } + + aiContent := tzResp.Choices[0].Message.Content + if len(aiContent) > 300 { + fmt.Printf("🤖 AI Response preview: %s...\n", aiContent[:300]) + } else { + fmt.Printf("🤖 AI Response: %s\n", aiContent) + } + + // Check if this is a resolution 
response (final) + var resolutionResp struct { + ResponseType string `json:"response_type"` + RootCause string `json:"root_cause"` + ResolutionPlan string `json:"resolution_plan"` + Confidence string `json:"confidence"` + } + + fmt.Printf("🔍 Analyzing AI response type...\n") + + if err := json.Unmarshal([]byte(aiContent), &resolutionResp); err == nil && resolutionResp.ResponseType == "resolution" { + // This is the final resolution - show summary and complete + fmt.Printf("✅ Detected RESOLUTION response - completing investigation\n") + fmt.Printf("\n=== DIAGNOSIS COMPLETE ===\n") + fmt.Printf("Root Cause: %s\n", resolutionResp.RootCause) + fmt.Printf("Resolution Plan: %s\n", resolutionResp.ResolutionPlan) + fmt.Printf("Confidence: %s\n", resolutionResp.Confidence) + finalAIContent = aiContent + break + } + + // Check if this is another diagnostic response requiring more commands + var diagnosticResp struct { + ResponseType string `json:"response_type"` + Commands []interface{} `json:"commands"` + EBPFPrograms []interface{} `json:"ebpf_programs"` + } + + if err := json.Unmarshal([]byte(aiContent), &diagnosticResp); err == nil && diagnosticResp.ResponseType == "diagnostic" { + fmt.Printf("✅ Detected DIAGNOSTIC response - continuing conversation\n") + fmt.Printf("🔄 AI requested additional diagnostics, executing...\n") + + // Execute additional commands if any + additionalResults := map[string]interface{}{ + "command_results": []map[string]interface{}{}, + } + + if len(diagnosticResp.Commands) > 0 { + fmt.Printf("🔧 Executing %d additional diagnostic commands...\n", len(diagnosticResp.Commands)) + commandResults := c.executeCommandsFromPayload(diagnosticResp.Commands) + additionalResults["command_results"] = commandResults + } + + // Execute additional eBPF programs if any + if len(diagnosticResp.EBPFPrograms) > 0 { + fmt.Printf("đŸ”Ŧ Executing %d additional eBPF programs...\n", len(diagnosticResp.EBPFPrograms)) + ebpfResults := 
c.executeEBPFPrograms(diagnosticResp.EBPFPrograms) + additionalResults["ebpf_results"] = ebpfResults + } + + // Add AI response and additional results to conversation + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: aiContent, + }) + + additionalResultsJSON, _ := json.MarshalIndent(additionalResults, "", " ") + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: string(additionalResultsJSON), + }) + + continue + } + + // If neither resolution nor diagnostic, treat as final response + fmt.Printf("âš ī¸ Unknown response type - treating as final response\n") + finalAIContent = aiContent + break + } + + // Attach final AI response to results for DB and mark as completed_with_analysis + resultsForDB["ai_response"] = finalAIContent + fmt.Printf("💾 Updating database with results and AI analysis...\n") + c.updateInvestigationStatus(investigation.ID, "completed_with_analysis", resultsForDB, nil) + fmt.Printf("✅ Investigation completed with AI analysis\n") +} + +// updateInvestigationStatus updates the status of a pending investigation +func (c *WebSocketClient) updateInvestigationStatus(id, status string, results map[string]interface{}, errorMsg *string) error { + updateData := map[string]interface{}{ + "status": status, + } + + if status == "executing" { + updateData["started_at"] = time.Now().UTC().Format(time.RFC3339) + } else if status == "completed" { + updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339) + if results != nil { + updateData["command_results"] = results + } + } else if status == "failed" && errorMsg != nil { + updateData["error_message"] = *errorMsg + updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339) + } + + jsonData, err := json.Marshal(updateData) + if err != nil { + return fmt.Errorf("failed to marshal update data: %v", err) + } + + url := 
fmt.Sprintf("%s/functions/v1/agent-database-proxy/pending-investigations/%s", c.supabaseURL, id) + req, err := http.NewRequest("PATCH", url, strings.NewReader(string(jsonData))) + if err != nil { + return fmt.Errorf("failed to create request: %v", err) + } + + // Only JWT token needed for proxy - no API keys exposed + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to update investigation: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 && resp.StatusCode != 204 { + return fmt.Errorf("supabase update error: %d", resp.StatusCode) + } + + return nil +} + +// attemptReconnection attempts to reconnect the WebSocket with backoff +func (c *WebSocketClient) attemptReconnection() { + backoffDurations := []time.Duration{ + 2 * time.Second, + 5 * time.Second, + 10 * time.Second, + 20 * time.Second, + 30 * time.Second, + } + + for i, backoff := range backoffDurations { + select { + case <-c.ctx.Done(): + return + default: + c.consecutiveFailures++ + + // Only show messages after 5 consecutive failures + if c.consecutiveFailures >= 5 { + log.Printf("🔄 Attempting WebSocket reconnection (attempt %d/%d) - %d consecutive failures", i+1, len(backoffDurations), c.consecutiveFailures) + } + + time.Sleep(backoff) + + if err := c.connect(); err != nil { + if c.consecutiveFailures >= 5 { + log.Printf("❌ Reconnection attempt %d failed: %v", i+1, err) + } + continue + } + + // Successfully reconnected - reset failure counter + if c.consecutiveFailures >= 5 { + log.Printf("✅ WebSocket reconnected successfully after %d failures", c.consecutiveFailures) + } + c.consecutiveFailures = 0 + go c.handleMessages() // Restart message handling + return + } + } + + log.Printf("❌ Failed to reconnect after %d attempts, giving up", len(backoffDurations)) +}