Compare commits

...

4 Commits

Author SHA256 Message Date
Harshavardhan Musanalli
8328f8d5b3 Integrate-with-supabase-backend 2025-10-28 07:53:14 +01:00
Harshavardhan Musanalli
8832450a1f Agent and websocket investigations work fine 2025-10-27 19:13:39 +01:00
Harshavardhan Musanalli
0a8b2dc202 Working code with Tensorzero through Supabase proxy 2025-10-25 15:16:03 +02:00
Harshavardhan Musanalli
6fd403cb5f Integrate with supabase backend 2025-10-25 12:39:48 +02:00
18 changed files with 3155 additions and 72 deletions

3
.gitignore vendored
View File

@@ -23,6 +23,7 @@ go.work
go.work.sum
# env file
.env
.env*
nannyagent*
nanny-agent*
.vscode

View File

250
agent.go
View File

@@ -55,10 +55,11 @@ type LinuxDiagnosticAgent struct {
// NewLinuxDiagnosticAgent creates a new diagnostic agent
func NewLinuxDiagnosticAgent() *LinuxDiagnosticAgent {
endpoint := os.Getenv("NANNYAPI_ENDPOINT")
if endpoint == "" {
// Default endpoint - OpenAI SDK will append /chat/completions automatically
endpoint = "http://tensorzero.netcup.internal:3000/openai/v1"
// Get Supabase project URL for TensorZero proxy
supabaseURL := os.Getenv("SUPABASE_PROJECT_URL")
if supabaseURL == "" {
fmt.Printf("Warning: SUPABASE_PROJECT_URL not set, TensorZero integration will not work\n")
supabaseURL = "https://gpqzsricripnvbrpsyws.supabase.co" // fallback
}
model := os.Getenv("NANNYAPI_MODEL")
@@ -67,14 +68,9 @@ func NewLinuxDiagnosticAgent() *LinuxDiagnosticAgent {
fmt.Printf("Warning: Using default model '%s'. Set NANNYAPI_MODEL environment variable for your specific function.\n", model)
}
// Create OpenAI client with custom base URL
// Note: The OpenAI SDK automatically appends "/chat/completions" to the base URL
config := openai.DefaultConfig("")
config.BaseURL = endpoint
client := openai.NewClientWithConfig(config)
// Note: We don't use the OpenAI client anymore, we use direct HTTP to Supabase proxy
agent := &LinuxDiagnosticAgent{
client: client,
client: nil, // Not used anymore
model: model,
executor: NewCommandExecutor(10 * time.Second), // 10 second timeout for commands
}
@@ -106,7 +102,7 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error {
for {
// Send request to TensorZero API via OpenAI SDK
response, err := a.sendRequest(messages)
response, err := a.sendRequestWithEpisode(messages, a.episodeID)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
@@ -119,34 +115,60 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error {
fmt.Printf("\nAI Response:\n%s\n", content)
// Parse the response to determine next action
var diagnosticResp DiagnosticResponse
var diagnosticResp EBPFEnhancedDiagnosticResponse
var resolutionResp ResolutionResponse
// Try to parse as diagnostic response first
// Try to parse as diagnostic response first (with eBPF support)
if err := json.Unmarshal([]byte(content), &diagnosticResp); err == nil && diagnosticResp.ResponseType == "diagnostic" {
// Handle diagnostic phase
fmt.Printf("\nReasoning: %s\n", diagnosticResp.Reasoning)
if len(diagnosticResp.Commands) == 0 {
fmt.Println("No commands to execute in diagnostic phase")
break
}
// Execute commands and collect results
commandResults := make([]CommandResult, 0, len(diagnosticResp.Commands))
if len(diagnosticResp.Commands) > 0 {
fmt.Printf("🔧 Executing diagnostic commands...\n")
for _, cmd := range diagnosticResp.Commands {
fmt.Printf("\nExecuting command '%s': %s\n", cmd.ID, cmd.Command)
result := a.executor.Execute(cmd)
commandResults = append(commandResults, result)
fmt.Printf("Output:\n%s\n", result.Output)
if result.Error != "" {
fmt.Printf("Error: %s\n", result.Error)
if result.ExitCode != 0 {
fmt.Printf("❌ Command '%s' failed with exit code %d\n", cmd.ID, result.ExitCode)
}
}
}
// Prepare command results as user message
resultsJSON, err := json.MarshalIndent(commandResults, "", " ")
// Execute eBPF programs if present
var ebpfResults []map[string]interface{}
if len(diagnosticResp.EBPFPrograms) > 0 {
ebpfResults = a.executeEBPFPrograms(diagnosticResp.EBPFPrograms)
}
// Prepare combined results as user message
allResults := map[string]interface{}{
"command_results": commandResults,
"executed_commands": len(commandResults),
}
// Include eBPF results if any were executed
if len(ebpfResults) > 0 {
allResults["ebpf_results"] = ebpfResults
allResults["executed_ebpf_programs"] = len(ebpfResults)
// Extract evidence summary for TensorZero
evidenceSummary := make([]string, 0)
for _, result := range ebpfResults {
name := result["name"]
eventCount := result["data_points"]
description := result["description"]
status := result["status"]
summaryStr := fmt.Sprintf("%s: %v events (%s) - %s", name, eventCount, status, description)
evidenceSummary = append(evidenceSummary, summaryStr)
}
allResults["ebpf_evidence_summary"] = evidenceSummary
}
resultsJSON, err := json.MarshalIndent(allResults, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal command results: %w", err)
}
@@ -182,6 +204,119 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error {
return nil
}
// executeEBPFPrograms executes REAL eBPF monitoring programs using the actual eBPF manager
func (a *LinuxDiagnosticAgent) executeEBPFPrograms(ebpfPrograms []EBPFRequest) []map[string]interface{} {
var results []map[string]interface{}
if a.ebpfManager == nil {
fmt.Printf("❌ eBPF manager not initialized\n")
return results
}
for _, prog := range ebpfPrograms {
// eBPF program starting - only show in debug mode
// Actually start the eBPF program using the real manager
programID, err := a.ebpfManager.StartEBPFProgram(prog)
if err != nil {
fmt.Printf("❌ Failed to start eBPF program [%s]: %v\n", prog.Name, err)
result := map[string]interface{}{
"name": prog.Name,
"type": prog.Type,
"target": prog.Target,
"duration": int(prog.Duration),
"description": prog.Description,
"status": "failed",
"error": err.Error(),
"success": false,
}
results = append(results, result)
continue
}
// Let the eBPF program run for the specified duration
time.Sleep(time.Duration(prog.Duration) * time.Second)
// Give the collectEvents goroutine a moment to finish and store results
time.Sleep(500 * time.Millisecond)
// Use a channel to implement timeout for GetProgramResults
type resultPair struct {
trace *EBPFTrace
err error
}
resultChan := make(chan resultPair, 1)
go func() {
trace, err := a.ebpfManager.GetProgramResults(programID)
resultChan <- resultPair{trace, err}
}()
var trace *EBPFTrace
var resultErr error
select {
case result := <-resultChan:
trace = result.trace
resultErr = result.err
case <-time.After(3 * time.Second):
resultErr = fmt.Errorf("timeout getting results after 3 seconds")
}
// Try to stop the program (may already be stopped by collectEvents)
stopErr := a.ebpfManager.StopProgram(programID)
if stopErr != nil {
// Only show warning in debug mode - this is normal for completed programs
}
if resultErr != nil {
fmt.Printf("❌ Failed to get results for eBPF program [%s]: %v\n", prog.Name, resultErr)
result := map[string]interface{}{
"name": prog.Name,
"type": prog.Type,
"target": prog.Target,
"duration": int(prog.Duration),
"description": prog.Description,
"status": "collection_failed",
"error": resultErr.Error(),
"success": false,
}
results = append(results, result)
continue
} // Process the real eBPF trace data
result := map[string]interface{}{
"name": prog.Name,
"type": prog.Type,
"target": prog.Target,
"duration": int(prog.Duration),
"description": prog.Description,
"status": "completed",
"success": true,
}
// Extract real data from the trace
if trace != nil {
result["trace_id"] = trace.TraceID
result["data_points"] = trace.EventCount
result["events"] = trace.Events
result["summary"] = trace.Summary
result["process_list"] = trace.ProcessList
result["start_time"] = trace.StartTime.Format(time.RFC3339)
result["end_time"] = trace.EndTime.Format(time.RFC3339)
result["actual_duration"] = trace.EndTime.Sub(trace.StartTime).Seconds()
} else {
result["data_points"] = 0
result["error"] = "No trace data returned"
fmt.Printf("⚠️ eBPF program [%s] completed but returned no trace data\n", prog.Name)
}
results = append(results, result)
}
return results
}
// TensorZeroRequest represents a request structure compatible with TensorZero's episode_id
type TensorZeroRequest struct {
Model string `json:"model"`
@@ -195,8 +330,13 @@ type TensorZeroResponse struct {
EpisodeID string `json:"episode_id"`
}
// sendRequest sends a request to the TensorZero API with tensorzero::episode_id support
// sendRequest sends a request to the TensorZero API via Supabase proxy with JWT authentication
func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessage) (*openai.ChatCompletionResponse, error) {
return a.sendRequestWithEpisode(messages, "")
}
// sendRequestWithEpisode sends a request with a specific episode ID
func (a *LinuxDiagnosticAgent) sendRequestWithEpisode(messages []openai.ChatCompletionMessage, episodeID string) (*openai.ChatCompletionResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@@ -206,9 +346,12 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa
Messages: messages,
}
// Include tensorzero::episode_id for conversation continuity (if we have one)
// Include tensorzero::episode_id for conversation continuity
// Use agent's existing episode ID if available, otherwise use provided one
if a.episodeID != "" {
tzRequest.EpisodeID = a.episodeID
} else if episodeID != "" {
tzRequest.EpisodeID = episodeID
}
fmt.Printf("Debug: Sending request to model: %s", a.model)
@@ -223,17 +366,14 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
// Create HTTP request
endpoint := os.Getenv("NANNYAPI_ENDPOINT")
if endpoint == "" {
endpoint = "http://tensorzero.netcup.internal:3000/openai/v1"
// Get Supabase project URL and build TensorZero proxy endpoint
supabaseURL := os.Getenv("SUPABASE_PROJECT_URL")
if supabaseURL == "" {
supabaseURL = "https://gpqzsricripnvbrpsyws.supabase.co"
}
// Ensure the endpoint ends with /chat/completions
if endpoint[len(endpoint)-1] != '/' {
endpoint += "/"
}
endpoint += "chat/completions"
// Build Supabase function URL with OpenAI v1 compatible path
endpoint := supabaseURL + "/functions/v1/tensorzero-proxy/openai/v1/chat/completions"
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(requestBody))
if err != nil {
@@ -242,6 +382,14 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa
req.Header.Set("Content-Type", "application/json")
// Add JWT authentication header
accessToken, err := a.getAccessToken()
if err != nil {
return nil, fmt.Errorf("failed to get access token: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
// Make the request
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
@@ -257,7 +405,7 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("TensorZero API request failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse TensorZero response
@@ -274,3 +422,31 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa
return &tzResponse.ChatCompletionResponse, nil
}
// getAccessToken retrieves the current access token for authentication
func (a *LinuxDiagnosticAgent) getAccessToken() (string, error) {
// Read token from the standard token file location
tokenPath := os.Getenv("TOKEN_PATH")
if tokenPath == "" {
tokenPath = "/var/lib/nannyagent/token.json"
}
tokenData, err := os.ReadFile(tokenPath)
if err != nil {
return "", fmt.Errorf("failed to read token file: %w", err)
}
var tokenInfo struct {
AccessToken string `json:"access_token"`
}
if err := json.Unmarshal(tokenData, &tokenInfo); err != nil {
return "", fmt.Errorf("failed to parse token file: %w", err)
}
if tokenInfo.AccessToken == "" {
return "", fmt.Errorf("access token is empty")
}
return tokenInfo.AccessToken, nil
}

14
go.mod
View File

@@ -6,7 +6,19 @@ toolchain go1.24.2
require (
github.com/cilium/ebpf v0.19.0
github.com/joho/godotenv v1.5.1
github.com/sashabaranov/go-openai v1.32.0
github.com/shirou/gopsutil/v3 v3.24.5
)
require golang.org/x/sys v0.31.0 // indirect
require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
golang.org/x/sys v0.31.0 // indirect
)

36
go.sum
View File

@@ -1,9 +1,18 @@
github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
@@ -12,17 +21,44 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/sashabaranov/go-openai v1.32.0 h1:Yk3iE9moX3RBXxrof3OBtUBrE7qZR0zF9ebsoO4zVzI=
github.com/sashabaranov/go-openai v1.32.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

509
internal/auth/auth.go Normal file
View File

@@ -0,0 +1,509 @@
package auth
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"nannyagentv2/internal/config"
"nannyagentv2/internal/types"
)
const (
// Token storage location (secure directory)
TokenStorageDir = "/var/lib/nannyagent"
TokenStorageFile = ".agent_token.json"
RefreshTokenFile = ".refresh_token"
// Polling configuration
MaxPollAttempts = 60 // 5 minutes (60 * 5 seconds)
PollInterval = 5 * time.Second
)
// AuthManager handles all authentication-related operations
type AuthManager struct {
config *config.Config
client *http.Client
}
// NewAuthManager creates a new authentication manager
func NewAuthManager(cfg *config.Config) *AuthManager {
return &AuthManager{
config: cfg,
client: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// EnsureTokenStorageDir creates the token storage directory if it doesn't exist
func (am *AuthManager) EnsureTokenStorageDir() error {
// Check if running as root
if os.Geteuid() != 0 {
return fmt.Errorf("must run as root to create secure token storage directory")
}
// Create directory with restricted permissions (0700 - only root can access)
if err := os.MkdirAll(TokenStorageDir, 0700); err != nil {
return fmt.Errorf("failed to create token storage directory: %w", err)
}
return nil
}
// StartDeviceAuthorization initiates the OAuth device authorization flow
func (am *AuthManager) StartDeviceAuthorization() (*types.DeviceAuthResponse, error) {
payload := map[string]interface{}{
"client_id": "nannyagent-cli",
"scope": []string{"agent:register"},
}
jsonData, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal payload: %w", err)
}
url := fmt.Sprintf("%s/device/authorize", am.config.DeviceAuthURL)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := am.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to start device authorization: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("device authorization failed with status %d: %s", resp.StatusCode, string(body))
}
var deviceResp types.DeviceAuthResponse
if err := json.Unmarshal(body, &deviceResp); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return &deviceResp, nil
}
// PollForToken polls the token endpoint until authorization is complete
func (am *AuthManager) PollForToken(deviceCode string) (*types.TokenResponse, error) {
fmt.Println("⏳ Waiting for user authorization...")
for attempts := 0; attempts < MaxPollAttempts; attempts++ {
tokenReq := types.TokenRequest{
GrantType: "urn:ietf:params:oauth:grant-type:device_code",
DeviceCode: deviceCode,
}
jsonData, err := json.Marshal(tokenReq)
if err != nil {
return nil, fmt.Errorf("failed to marshal token request: %w", err)
}
url := fmt.Sprintf("%s/token", am.config.DeviceAuthURL)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := am.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to poll for token: %w", err)
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, fmt.Errorf("failed to read token response: %w", err)
}
var tokenResp types.TokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse token response: %w", err)
}
if tokenResp.Error != "" {
if tokenResp.Error == "authorization_pending" {
fmt.Print(".")
time.Sleep(PollInterval)
continue
}
return nil, fmt.Errorf("authorization failed: %s", tokenResp.ErrorDescription)
}
if tokenResp.AccessToken != "" {
fmt.Println("\n✅ Authorization successful!")
return &tokenResp, nil
}
time.Sleep(PollInterval)
}
return nil, fmt.Errorf("authorization timed out after %d attempts", MaxPollAttempts)
}
// RefreshAccessToken refreshes an expired access token using the refresh token
func (am *AuthManager) RefreshAccessToken(refreshToken string) (*types.TokenResponse, error) {
tokenReq := types.TokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
}
jsonData, err := json.Marshal(tokenReq)
if err != nil {
return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
}
url := fmt.Sprintf("%s/token", am.config.DeviceAuthURL)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create refresh request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := am.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read refresh response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
}
var tokenResp types.TokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse refresh response: %w", err)
}
if tokenResp.Error != "" {
return nil, fmt.Errorf("token refresh failed: %s", tokenResp.ErrorDescription)
}
return &tokenResp, nil
}
// SaveToken saves the authentication token to secure local storage
func (am *AuthManager) SaveToken(token *types.AuthToken) error {
if err := am.EnsureTokenStorageDir(); err != nil {
return fmt.Errorf("failed to ensure token storage directory: %w", err)
}
// Save main token file
tokenPath := am.getTokenPath()
jsonData, err := json.MarshalIndent(token, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal token: %w", err)
}
if err := os.WriteFile(tokenPath, jsonData, 0600); err != nil {
return fmt.Errorf("failed to save token: %w", err)
}
// Also save refresh token separately for backup recovery
if token.RefreshToken != "" {
refreshTokenPath := filepath.Join(TokenStorageDir, RefreshTokenFile)
if err := os.WriteFile(refreshTokenPath, []byte(token.RefreshToken), 0600); err != nil {
// Don't fail if refresh token backup fails, just log
fmt.Printf("Warning: Failed to save backup refresh token: %v\n", err)
}
}
return nil
} // LoadToken loads the authentication token from secure local storage
func (am *AuthManager) LoadToken() (*types.AuthToken, error) {
tokenPath := am.getTokenPath()
data, err := os.ReadFile(tokenPath)
if err != nil {
return nil, fmt.Errorf("failed to read token file: %w", err)
}
var token types.AuthToken
if err := json.Unmarshal(data, &token); err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
// Check if token is expired
if time.Now().After(token.ExpiresAt.Add(-5 * time.Minute)) {
return nil, fmt.Errorf("token is expired or expiring soon")
}
return &token, nil
}
// IsTokenExpired checks if a token needs refresh
func (am *AuthManager) IsTokenExpired(token *types.AuthToken) bool {
// Consider token expired if it expires within the next 5 minutes
return time.Now().After(token.ExpiresAt.Add(-5 * time.Minute))
}
// RegisterDevice performs the complete device registration flow
func (am *AuthManager) RegisterDevice() (*types.AuthToken, error) {
// Step 1: Start device authorization
deviceAuth, err := am.StartDeviceAuthorization()
if err != nil {
return nil, fmt.Errorf("failed to start device authorization: %w", err)
}
fmt.Printf("Please visit: %s\n", deviceAuth.VerificationURI)
fmt.Printf("And enter code: %s\n", deviceAuth.UserCode)
// Step 2: Poll for token
tokenResp, err := am.PollForToken(deviceAuth.DeviceCode)
if err != nil {
return nil, fmt.Errorf("failed to get token: %w", err)
}
// Step 3: Create token storage
token := &types.AuthToken{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType,
ExpiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
AgentID: tokenResp.AgentID,
}
// Step 4: Save token
if err := am.SaveToken(token); err != nil {
return nil, fmt.Errorf("failed to save token: %w", err)
}
return token, nil
}
// EnsureAuthenticated ensures the agent has a valid token, refreshing if necessary
func (am *AuthManager) EnsureAuthenticated() (*types.AuthToken, error) {
// Try to load existing token
token, err := am.LoadToken()
if err == nil && !am.IsTokenExpired(token) {
return token, nil
}
// Try to refresh with existing refresh token (even if access token is missing/expired)
var refreshToken string
if err == nil && token.RefreshToken != "" {
// Use refresh token from loaded token
refreshToken = token.RefreshToken
} else {
// Try to load refresh token from main token file even if load failed
if existingToken, loadErr := am.loadTokenIgnoringExpiry(); loadErr == nil && existingToken.RefreshToken != "" {
refreshToken = existingToken.RefreshToken
} else {
// Try to load refresh token from backup file
if backupRefreshToken, backupErr := am.loadRefreshTokenFromBackup(); backupErr == nil {
refreshToken = backupRefreshToken
fmt.Println("🔄 Found backup refresh token, attempting to use it...")
}
}
}
if refreshToken != "" {
fmt.Println("🔄 Attempting to refresh access token...")
refreshResp, refreshErr := am.RefreshAccessToken(refreshToken)
if refreshErr == nil {
// Get existing agent_id from current token or backup
var agentID string
if err == nil && token.AgentID != "" {
agentID = token.AgentID
} else if existingToken, loadErr := am.loadTokenIgnoringExpiry(); loadErr == nil {
agentID = existingToken.AgentID
}
// Create new token with refreshed values
newToken := &types.AuthToken{
AccessToken: refreshResp.AccessToken,
RefreshToken: refreshToken, // Keep existing refresh token
TokenType: refreshResp.TokenType,
ExpiresAt: time.Now().Add(time.Duration(refreshResp.ExpiresIn) * time.Second),
AgentID: agentID, // Preserve agent_id
}
// Update refresh token if a new one was provided
if refreshResp.RefreshToken != "" {
newToken.RefreshToken = refreshResp.RefreshToken
}
if saveErr := am.SaveToken(newToken); saveErr == nil {
return newToken, nil
}
} else {
fmt.Printf("⚠️ Token refresh failed: %v\n", refreshErr)
}
}
fmt.Println("📝 Initiating new device registration...")
return am.RegisterDevice()
}
// loadTokenIgnoringExpiry loads token file without checking expiry
func (am *AuthManager) loadTokenIgnoringExpiry() (*types.AuthToken, error) {
tokenPath := am.getTokenPath()
data, err := os.ReadFile(tokenPath)
if err != nil {
return nil, fmt.Errorf("failed to read token file: %w", err)
}
var token types.AuthToken
if err := json.Unmarshal(data, &token); err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
return &token, nil
}
// loadRefreshTokenFromBackup tries to load refresh token from backup file
func (am *AuthManager) loadRefreshTokenFromBackup() (string, error) {
refreshTokenPath := filepath.Join(TokenStorageDir, RefreshTokenFile)
data, err := os.ReadFile(refreshTokenPath)
if err != nil {
return "", fmt.Errorf("failed to read refresh token backup: %w", err)
}
refreshToken := strings.TrimSpace(string(data))
if refreshToken == "" {
return "", fmt.Errorf("refresh token backup is empty")
}
return refreshToken, nil
}
// GetCurrentAgentID retrieves the agent ID from cache or JWT token
func (am *AuthManager) GetCurrentAgentID() (string, error) {
// First try to read from local cache
agentID, err := am.loadCachedAgentID()
if err == nil && agentID != "" {
return agentID, nil
}
// Cache miss - extract from JWT token and cache it
token, err := am.LoadToken()
if err != nil {
return "", fmt.Errorf("failed to load token: %w", err)
}
// Extract agent ID from JWT 'sub' field
agentID, err = am.extractAgentIDFromJWT(token.AccessToken)
if err != nil {
return "", fmt.Errorf("failed to extract agent ID from JWT: %w", err)
}
// Cache the agent ID for future use
if err := am.cacheAgentID(agentID); err != nil {
// Log warning but don't fail - we still have the agent ID
fmt.Printf("Warning: Failed to cache agent ID: %v\n", err)
}
return agentID, nil
}
// extractAgentIDFromJWT decodes the JWT token and extracts the agent ID from 'sub' field
func (am *AuthManager) extractAgentIDFromJWT(tokenString string) (string, error) {
// Basic JWT decoding without verification (since we trust Supabase)
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return "", fmt.Errorf("invalid JWT token format")
}
// Decode the payload (second part)
payload := parts[1]
// Add padding if needed for base64 decoding
for len(payload)%4 != 0 {
payload += "="
}
decoded, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
return "", fmt.Errorf("failed to decode JWT payload: %w", err)
}
// Parse JSON payload
var claims map[string]interface{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return "", fmt.Errorf("failed to parse JWT claims: %w", err)
}
// The agent ID is in the 'sub' field (subject)
if agentID, ok := claims["sub"].(string); ok && agentID != "" {
return agentID, nil
}
return "", fmt.Errorf("agent ID (sub) not found in JWT claims")
}
// loadCachedAgentID reads the cached agent ID from local storage
func (am *AuthManager) loadCachedAgentID() (string, error) {
agentIDPath := filepath.Join(TokenStorageDir, "agent_id")
data, err := os.ReadFile(agentIDPath)
if err != nil {
return "", fmt.Errorf("failed to read cached agent ID: %w", err)
}
agentID := strings.TrimSpace(string(data))
if agentID == "" {
return "", fmt.Errorf("cached agent ID is empty")
}
return agentID, nil
}
// cacheAgentID stores the agent ID in local cache
func (am *AuthManager) cacheAgentID(agentID string) error {
// Ensure the directory exists
if err := am.EnsureTokenStorageDir(); err != nil {
return fmt.Errorf("failed to ensure storage directory: %w", err)
}
agentIDPath := filepath.Join(TokenStorageDir, "agent_id")
// Write agent ID to file with secure permissions
if err := os.WriteFile(agentIDPath, []byte(agentID), 0600); err != nil {
return fmt.Errorf("failed to write agent ID cache: %w", err)
}
return nil
}
func (am *AuthManager) getTokenPath() string {
if am.config.TokenPath != "" {
return am.config.TokenPath
}
return filepath.Join(TokenStorageDir, TokenStorageFile)
}
func getHostname() string {
if hostname, err := os.Hostname(); err == nil {
return hostname
}
return "unknown"
}

131
internal/config/config.go Normal file
View File

@@ -0,0 +1,131 @@
package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/joho/godotenv"
)
type Config struct {
// Supabase Configuration
SupabaseProjectURL string
// Edge Function Endpoints (auto-generated from SupabaseProjectURL)
DeviceAuthURL string
AgentAuthURL string
// Agent Configuration
TokenPath string
MetricsInterval int
// Debug/Development
Debug bool
}
var DefaultConfig = Config{
TokenPath: "./token.json",
MetricsInterval: 30,
Debug: false,
}
// LoadConfig loads configuration from environment variables and .env file
func LoadConfig() (*Config, error) {
config := DefaultConfig
// Try to load .env file from current directory or parent directories
envFile := findEnvFile()
if envFile != "" {
if err := godotenv.Load(envFile); err != nil {
fmt.Printf("Warning: Could not load .env file from %s: %v\n", envFile, err)
} else {
fmt.Printf("Loaded configuration from %s\n", envFile)
}
}
// Load from environment variables
if url := os.Getenv("SUPABASE_PROJECT_URL"); url != "" {
config.SupabaseProjectURL = url
}
if tokenPath := os.Getenv("TOKEN_PATH"); tokenPath != "" {
config.TokenPath = tokenPath
}
if debug := os.Getenv("DEBUG"); debug == "true" || debug == "1" {
config.Debug = true
}
// Auto-generate edge function URLs from project URL
if config.SupabaseProjectURL != "" {
config.DeviceAuthURL = fmt.Sprintf("%s/functions/v1/device-auth", config.SupabaseProjectURL)
config.AgentAuthURL = fmt.Sprintf("%s/functions/v1/agent-auth-api", config.SupabaseProjectURL)
}
// Validate required configuration
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("configuration validation failed: %w", err)
}
return &config, nil
}
// Validate checks if all required configuration is present
func (c *Config) Validate() error {
var missing []string
if c.SupabaseProjectURL == "" {
missing = append(missing, "SUPABASE_PROJECT_URL")
}
if c.DeviceAuthURL == "" {
missing = append(missing, "DEVICE_AUTH_URL (or SUPABASE_PROJECT_URL)")
}
if c.AgentAuthURL == "" {
missing = append(missing, "AGENT_AUTH_URL (or SUPABASE_PROJECT_URL)")
}
if len(missing) > 0 {
return fmt.Errorf("missing required environment variables: %s", strings.Join(missing, ", "))
}
return nil
}
// findEnvFile looks for .env file in current directory and parent directories
func findEnvFile() string {
dir, err := os.Getwd()
if err != nil {
return ""
}
for {
envPath := filepath.Join(dir, ".env")
if _, err := os.Stat(envPath); err == nil {
return envPath
}
parent := filepath.Dir(dir)
if parent == dir {
break
}
dir = parent
}
return ""
}
// PrintConfig prints the current configuration (masking sensitive values)
func (c *Config) PrintConfig() {
if !c.Debug {
return
}
fmt.Println("Configuration:")
fmt.Printf(" Supabase Project URL: %s\n", c.SupabaseProjectURL)
fmt.Printf(" Metrics Interval: %d seconds\n", c.MetricsInterval)
fmt.Printf(" Debug: %v\n", c.Debug)
}

View File

@@ -0,0 +1,90 @@
package logging
import (
"fmt"
"log"
"log/syslog"
"os"
)
type Logger struct {
syslogWriter *syslog.Writer
debugMode bool
}
var defaultLogger *Logger
func init() {
defaultLogger = NewLogger()
}
func NewLogger() *Logger {
l := &Logger{
debugMode: os.Getenv("DEBUG") == "true",
}
// Try to connect to syslog
if writer, err := syslog.New(syslog.LOG_INFO|syslog.LOG_DAEMON, "nannyagentv2"); err == nil {
l.syslogWriter = writer
}
return l
}
func (l *Logger) Info(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
if l.syslogWriter != nil {
l.syslogWriter.Info(msg)
}
log.Printf("[INFO] %s", msg)
}
func (l *Logger) Debug(format string, args ...interface{}) {
if !l.debugMode {
return
}
msg := fmt.Sprintf(format, args...)
if l.syslogWriter != nil {
l.syslogWriter.Debug(msg)
}
log.Printf("[DEBUG] %s", msg)
}
func (l *Logger) Warning(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
if l.syslogWriter != nil {
l.syslogWriter.Warning(msg)
}
log.Printf("[WARNING] %s", msg)
}
func (l *Logger) Error(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
if l.syslogWriter != nil {
l.syslogWriter.Err(msg)
}
log.Printf("[ERROR] %s", msg)
}
func (l *Logger) Close() {
if l.syslogWriter != nil {
l.syslogWriter.Close()
}
}
// Global logging functions
func Info(format string, args ...interface{}) {
defaultLogger.Info(format, args...)
}
func Debug(format string, args ...interface{}) {
defaultLogger.Debug(format, args...)
}
func Warning(format string, args ...interface{}) {
defaultLogger.Warning(format, args...)
}
func Error(format string, args ...interface{}) {
defaultLogger.Error(format, args...)
}

View File

@@ -0,0 +1,318 @@
package metrics
import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"strings"
"time"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
"github.com/shirou/gopsutil/v3/host"
"github.com/shirou/gopsutil/v3/load"
"github.com/shirou/gopsutil/v3/mem"
psnet "github.com/shirou/gopsutil/v3/net"
"nannyagentv2/internal/types"
)
// Collector handles system metrics collection
type Collector struct {
agentVersion string
}
// NewCollector creates a new metrics collector
func NewCollector(agentVersion string) *Collector {
return &Collector{
agentVersion: agentVersion,
}
}
// GatherSystemMetrics collects comprehensive system metrics
func (c *Collector) GatherSystemMetrics() (*types.SystemMetrics, error) {
metrics := &types.SystemMetrics{
Timestamp: time.Now(),
}
// System Information
if hostInfo, err := host.Info(); err == nil {
metrics.Hostname = hostInfo.Hostname
metrics.Platform = hostInfo.Platform
metrics.PlatformFamily = hostInfo.PlatformFamily
metrics.PlatformVersion = hostInfo.PlatformVersion
metrics.KernelVersion = hostInfo.KernelVersion
metrics.KernelArch = hostInfo.KernelArch
}
// CPU Metrics
if percentages, err := cpu.Percent(time.Second, false); err == nil && len(percentages) > 0 {
metrics.CPUUsage = math.Round(percentages[0]*100) / 100
}
if cpuInfo, err := cpu.Info(); err == nil && len(cpuInfo) > 0 {
metrics.CPUCores = len(cpuInfo)
metrics.CPUModel = cpuInfo[0].ModelName
}
// Memory Metrics
if memInfo, err := mem.VirtualMemory(); err == nil {
metrics.MemoryUsage = math.Round(float64(memInfo.Used)/(1024*1024)*100) / 100 // MB
metrics.MemoryTotal = memInfo.Total
metrics.MemoryUsed = memInfo.Used
metrics.MemoryFree = memInfo.Free
metrics.MemoryAvailable = memInfo.Available
}
if swapInfo, err := mem.SwapMemory(); err == nil {
metrics.SwapTotal = swapInfo.Total
metrics.SwapUsed = swapInfo.Used
metrics.SwapFree = swapInfo.Free
}
// Disk Metrics
if diskInfo, err := disk.Usage("/"); err == nil {
metrics.DiskUsage = math.Round(diskInfo.UsedPercent*100) / 100
metrics.DiskTotal = diskInfo.Total
metrics.DiskUsed = diskInfo.Used
metrics.DiskFree = diskInfo.Free
}
// Load Averages
if loadAvg, err := load.Avg(); err == nil {
metrics.LoadAvg1 = math.Round(loadAvg.Load1*100) / 100
metrics.LoadAvg5 = math.Round(loadAvg.Load5*100) / 100
metrics.LoadAvg15 = math.Round(loadAvg.Load15*100) / 100
}
// Process Count (simplified - using a constant for now)
// Note: gopsutil doesn't have host.Processes(), would need process.Processes()
metrics.ProcessCount = 0 // Placeholder
// Network Metrics
netIn, netOut := c.getNetworkStats()
metrics.NetworkInKbps = netIn
metrics.NetworkOutKbps = netOut
if netIOCounters, err := psnet.IOCounters(false); err == nil && len(netIOCounters) > 0 {
netIO := netIOCounters[0]
metrics.NetworkInBytes = netIO.BytesRecv
metrics.NetworkOutBytes = netIO.BytesSent
}
// IP Address and Location
metrics.IPAddress = c.getIPAddress()
metrics.Location = c.getLocation() // Placeholder
// Filesystem Information
metrics.FilesystemInfo = c.getFilesystemInfo()
// Block Devices
metrics.BlockDevices = c.getBlockDevices()
return metrics, nil
}
// getNetworkStats returns network input/output rates in Kbps
func (c *Collector) getNetworkStats() (float64, float64) {
netIOCounters, err := psnet.IOCounters(false)
if err != nil || len(netIOCounters) == 0 {
return 0.0, 0.0
}
// Use the first interface for aggregate stats
netIO := netIOCounters[0]
// Convert bytes to kilobits per second (simplified - cumulative bytes to kilobits)
netInKbps := float64(netIO.BytesRecv) * 8 / 1024
netOutKbps := float64(netIO.BytesSent) * 8 / 1024
return netInKbps, netOutKbps
}
// getIPAddress returns the primary IP address of the system
func (c *Collector) getIPAddress() string {
interfaces, err := psnet.Interfaces()
if err != nil {
return "unknown"
}
for _, iface := range interfaces {
if len(iface.Addrs) > 0 && !strings.Contains(iface.Addrs[0].Addr, "127.0.0.1") {
return strings.Split(iface.Addrs[0].Addr, "/")[0] // Remove CIDR if present
}
}
return "unknown"
}
// getLocation returns basic location information (placeholder)
func (c *Collector) getLocation() string {
return "unknown" // Would integrate with GeoIP service
}
// getFilesystemInfo returns information about mounted filesystems
func (c *Collector) getFilesystemInfo() []types.FilesystemInfo {
partitions, err := disk.Partitions(false)
if err != nil {
return []types.FilesystemInfo{}
}
var filesystems []types.FilesystemInfo
for _, partition := range partitions {
usage, err := disk.Usage(partition.Mountpoint)
if err != nil {
continue
}
fs := types.FilesystemInfo{
Mountpoint: partition.Mountpoint,
Fstype: partition.Fstype,
Total: usage.Total,
Used: usage.Used,
Free: usage.Free,
UsagePercent: math.Round(usage.UsedPercent*100) / 100,
}
filesystems = append(filesystems, fs)
}
return filesystems
}
// getBlockDevices returns information about block devices
func (c *Collector) getBlockDevices() []types.BlockDevice {
partitions, err := disk.Partitions(true)
if err != nil {
return []types.BlockDevice{}
}
var devices []types.BlockDevice
deviceMap := make(map[string]bool)
for _, partition := range partitions {
// Only include actual block devices
if strings.HasPrefix(partition.Device, "/dev/") {
deviceName := partition.Device
if !deviceMap[deviceName] {
deviceMap[deviceName] = true
device := types.BlockDevice{
Name: deviceName,
Model: "unknown",
Size: 0,
SerialNumber: "unknown",
}
devices = append(devices, device)
}
}
}
return devices
}
// SendMetrics sends system metrics to the agent-auth-api endpoint
func (c *Collector) SendMetrics(agentAuthURL, accessToken, agentID string, metrics *types.SystemMetrics) error {
// Create flattened metrics request for agent-auth-api
metricsReq := c.CreateMetricsRequest(agentID, metrics)
return c.sendMetricsRequest(agentAuthURL, accessToken, metricsReq)
}
// CreateMetricsRequest converts SystemMetrics to the flattened format expected by agent-auth-api
func (c *Collector) CreateMetricsRequest(agentID string, systemMetrics *types.SystemMetrics) *types.MetricsRequest {
return &types.MetricsRequest{
AgentID: agentID,
CPUUsage: systemMetrics.CPUUsage,
MemoryUsage: systemMetrics.MemoryUsage,
DiskUsage: systemMetrics.DiskUsage,
NetworkInKbps: systemMetrics.NetworkInKbps,
NetworkOutKbps: systemMetrics.NetworkOutKbps,
IPAddress: systemMetrics.IPAddress,
Location: systemMetrics.Location,
AgentVersion: c.agentVersion,
KernelVersion: systemMetrics.KernelVersion,
DeviceFingerprint: c.generateDeviceFingerprint(systemMetrics),
LoadAverages: map[string]float64{
"load1": systemMetrics.LoadAvg1,
"load5": systemMetrics.LoadAvg5,
"load15": systemMetrics.LoadAvg15,
},
OSInfo: map[string]string{
"cpu_cores": fmt.Sprintf("%d", systemMetrics.CPUCores),
"memory": fmt.Sprintf("%.1fGi", float64(systemMetrics.MemoryTotal)/(1024*1024*1024)),
"uptime": "unknown", // Will be calculated by the server or client
"platform": systemMetrics.Platform,
"platform_family": systemMetrics.PlatformFamily,
"platform_version": systemMetrics.PlatformVersion,
"kernel_version": systemMetrics.KernelVersion,
"kernel_arch": systemMetrics.KernelArch,
},
FilesystemInfo: systemMetrics.FilesystemInfo,
BlockDevices: systemMetrics.BlockDevices,
NetworkStats: map[string]uint64{
"bytes_sent": systemMetrics.NetworkOutBytes,
"bytes_recv": systemMetrics.NetworkInBytes,
"total_bytes": systemMetrics.NetworkInBytes + systemMetrics.NetworkOutBytes,
},
}
}
// sendMetricsRequest sends the metrics request to the agent-auth-api
func (c *Collector) sendMetricsRequest(agentAuthURL, accessToken string, metricsReq *types.MetricsRequest) error {
// Wrap metrics in the expected payload structure
payload := map[string]interface{}{
"metrics": metricsReq,
"timestamp": time.Now().UTC().Format(time.RFC3339),
}
jsonData, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal metrics: %w", err)
}
// Send to /metrics endpoint
metricsURL := fmt.Sprintf("%s/metrics", agentAuthURL)
req, err := http.NewRequest("POST", metricsURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to send metrics: %w", err)
}
defer resp.Body.Close()
// Read response
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
// Check response status
if resp.StatusCode == http.StatusUnauthorized {
return fmt.Errorf("unauthorized")
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("metrics request failed with status %d: %s", resp.StatusCode, string(body))
}
return nil
}
// generateDeviceFingerprint creates a unique device identifier
func (c *Collector) generateDeviceFingerprint(metrics *types.SystemMetrics) string {
fingerprint := fmt.Sprintf("%s-%s-%s", metrics.Hostname, metrics.Platform, metrics.KernelVersion)
hasher := sha256.New()
hasher.Write([]byte(fingerprint))
return fmt.Sprintf("%x", hasher.Sum(nil))[:16]
}

337
internal/types/types.go Normal file
View File

@@ -0,0 +1,337 @@
package types
import "time"
// SystemMetrics represents comprehensive system performance metrics
type SystemMetrics struct {
// System Information
Hostname string `json:"hostname"`
Platform string `json:"platform"`
PlatformFamily string `json:"platform_family"`
PlatformVersion string `json:"platform_version"`
KernelVersion string `json:"kernel_version"`
KernelArch string `json:"kernel_arch"`
// CPU Metrics
CPUUsage float64 `json:"cpu_usage"`
CPUCores int `json:"cpu_cores"`
CPUModel string `json:"cpu_model"`
// Memory Metrics
MemoryUsage float64 `json:"memory_usage"`
MemoryTotal uint64 `json:"memory_total"`
MemoryUsed uint64 `json:"memory_used"`
MemoryFree uint64 `json:"memory_free"`
MemoryAvailable uint64 `json:"memory_available"`
SwapTotal uint64 `json:"swap_total"`
SwapUsed uint64 `json:"swap_used"`
SwapFree uint64 `json:"swap_free"`
// Disk Metrics
DiskUsage float64 `json:"disk_usage"`
DiskTotal uint64 `json:"disk_total"`
DiskUsed uint64 `json:"disk_used"`
DiskFree uint64 `json:"disk_free"`
// Network Metrics
NetworkInKbps float64 `json:"network_in_kbps"`
NetworkOutKbps float64 `json:"network_out_kbps"`
NetworkInBytes uint64 `json:"network_in_bytes"`
NetworkOutBytes uint64 `json:"network_out_bytes"`
// System Load
LoadAvg1 float64 `json:"load_avg_1"`
LoadAvg5 float64 `json:"load_avg_5"`
LoadAvg15 float64 `json:"load_avg_15"`
// Process Information
ProcessCount int `json:"process_count"`
// Network Information
IPAddress string `json:"ip_address"`
Location string `json:"location"`
// Filesystem Information
FilesystemInfo []FilesystemInfo `json:"filesystem_info"`
BlockDevices []BlockDevice `json:"block_devices"`
// Timestamp
Timestamp time.Time `json:"timestamp"`
}
// FilesystemInfo represents individual filesystem statistics
type FilesystemInfo struct {
Mountpoint string `json:"mountpoint"`
Fstype string `json:"fstype"`
Total uint64 `json:"total"`
Used uint64 `json:"used"`
Free uint64 `json:"free"`
UsagePercent float64 `json:"usage_percent"`
}
// BlockDevice represents block device information
type BlockDevice struct {
Name string `json:"name"`
Size uint64 `json:"size"`
Model string `json:"model"`
SerialNumber string `json:"serial_number"`
}
// NetworkStats represents detailed network interface statistics
type NetworkStats struct {
InterfaceName string `json:"interface_name"`
BytesSent uint64 `json:"bytes_sent"`
BytesRecv uint64 `json:"bytes_recv"`
PacketsSent uint64 `json:"packets_sent"`
PacketsRecv uint64 `json:"packets_recv"`
ErrorsIn uint64 `json:"errors_in"`
ErrorsOut uint64 `json:"errors_out"`
DropsIn uint64 `json:"drops_in"`
DropsOut uint64 `json:"drops_out"`
}
// AuthToken represents the authentication token structure
type AuthToken struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt time.Time `json:"expires_at"`
TokenType string `json:"token_type"`
AgentID string `json:"agent_id"`
}
// DeviceAuthRequest represents the device authorization request
type DeviceAuthRequest struct {
ClientID string `json:"client_id"`
Scope string `json:"scope,omitempty"`
}
// DeviceAuthResponse represents the device authorization response
type DeviceAuthResponse struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
}
// TokenRequest represents the token request for device flow
type TokenRequest struct {
GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ClientID string `json:"client_id,omitempty"`
}
// TokenResponse represents the token response
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
AgentID string `json:"agent_id,omitempty"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
}
// HeartbeatRequest represents the agent heartbeat request
type HeartbeatRequest struct {
AgentID string `json:"agent_id"`
Status string `json:"status"`
Metrics SystemMetrics `json:"metrics"`
}
// MetricsRequest represents the flattened metrics payload expected by agent-auth-api
type MetricsRequest struct {
// Agent identification
AgentID string `json:"agent_id"`
// Basic metrics
CPUUsage float64 `json:"cpu_usage"`
MemoryUsage float64 `json:"memory_usage"`
DiskUsage float64 `json:"disk_usage"`
// Network metrics
NetworkInKbps float64 `json:"network_in_kbps"`
NetworkOutKbps float64 `json:"network_out_kbps"`
// System information
IPAddress string `json:"ip_address"`
Location string `json:"location"`
AgentVersion string `json:"agent_version"`
KernelVersion string `json:"kernel_version"`
DeviceFingerprint string `json:"device_fingerprint"`
// Structured data (JSON fields in database)
LoadAverages map[string]float64 `json:"load_averages"`
OSInfo map[string]string `json:"os_info"`
FilesystemInfo []FilesystemInfo `json:"filesystem_info"`
BlockDevices []BlockDevice `json:"block_devices"`
NetworkStats map[string]uint64 `json:"network_stats"`
}
// eBPF related types
type EBPFEvent struct {
Timestamp int64 `json:"timestamp"`
EventType string `json:"event_type"`
ProcessID int `json:"process_id"`
ProcessName string `json:"process_name"`
UserID int `json:"user_id"`
Data map[string]interface{} `json:"data"`
}
type EBPFTrace struct {
TraceID string `json:"trace_id"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
Capability string `json:"capability"`
Events []EBPFEvent `json:"events"`
Summary string `json:"summary"`
EventCount int `json:"event_count"`
ProcessList []string `json:"process_list"`
}
type EBPFRequest struct {
Name string `json:"name"`
Type string `json:"type"` // "tracepoint", "kprobe", "kretprobe"
Target string `json:"target"` // tracepoint path or function name
Duration int `json:"duration"` // seconds
Filters map[string]string `json:"filters,omitempty"`
Description string `json:"description"`
}
type NetworkEvent struct {
Timestamp uint64 `json:"timestamp"`
PID uint32 `json:"pid"`
TID uint32 `json:"tid"`
UID uint32 `json:"uid"`
EventType string `json:"event_type"`
Comm [16]byte `json:"-"`
CommStr string `json:"comm"`
}
// Agent types
type DiagnosticResponse struct {
ResponseType string `json:"response_type"`
Reasoning string `json:"reasoning"`
Commands []Command `json:"commands"`
}
type ResolutionResponse struct {
ResponseType string `json:"response_type"`
RootCause string `json:"root_cause"`
ResolutionPlan string `json:"resolution_plan"`
Confidence string `json:"confidence"`
}
type Command struct {
ID string `json:"id"`
Command string `json:"command"`
Description string `json:"description"`
}
type CommandResult struct {
ID string `json:"id"`
Command string `json:"command"`
Description string `json:"description"`
Output string `json:"output"`
ExitCode int `json:"exit_code"`
Error string `json:"error,omitempty"`
}
type EBPFEnhancedDiagnosticResponse struct {
ResponseType string `json:"response_type"`
Reasoning string `json:"reasoning"`
Commands []Command `json:"commands"`
EBPFPrograms []EBPFRequest `json:"ebpf_programs"`
NextActions []string `json:"next_actions,omitempty"`
}
type TensorZeroRequest struct {
Model string `json:"model"`
Messages []map[string]interface{} `json:"messages"`
EpisodeID string `json:"tensorzero::episode_id,omitempty"`
}
type TensorZeroResponse struct {
Choices []map[string]interface{} `json:"choices"`
EpisodeID string `json:"episode_id"`
}
// WebSocket types
type WebSocketMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
}
type InvestigationTask struct {
TaskID string `json:"task_id"`
InvestigationID string `json:"investigation_id"`
AgentID string `json:"agent_id"`
DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
EpisodeID string `json:"episode_id,omitempty"`
}
type TaskResult struct {
TaskID string `json:"task_id"`
Success bool `json:"success"`
CommandResults map[string]interface{} `json:"command_results,omitempty"`
Error string `json:"error,omitempty"`
}
type HeartbeatData struct {
AgentID string `json:"agent_id"`
Timestamp time.Time `json:"timestamp"`
Version string `json:"version"`
}
// Investigation server types
type InvestigationRequest struct {
Issue string `json:"issue"`
AgentID string `json:"agent_id"`
EpisodeID string `json:"episode_id,omitempty"`
Timestamp string `json:"timestamp,omitempty"`
Priority string `json:"priority,omitempty"`
Description string `json:"description,omitempty"`
}
type InvestigationResponse struct {
Status string `json:"status"`
Message string `json:"message"`
Results map[string]interface{} `json:"results,omitempty"`
AgentID string `json:"agent_id"`
Timestamp string `json:"timestamp"`
EpisodeID string `json:"episode_id,omitempty"`
Investigation *PendingInvestigation `json:"investigation,omitempty"`
}
type PendingInvestigation struct {
ID string `json:"id"`
Issue string `json:"issue"`
AgentID string `json:"agent_id"`
Status string `json:"status"`
DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
CommandResults map[string]interface{} `json:"command_results,omitempty"`
EpisodeID *string `json:"episode_id,omitempty"`
CreatedAt string `json:"created_at"`
StartedAt *string `json:"started_at,omitempty"`
CompletedAt *string `json:"completed_at,omitempty"`
ErrorMessage *string `json:"error_message,omitempty"`
}
// System types
type SystemInfo struct {
Hostname string `json:"hostname"`
Platform string `json:"platform"`
PlatformInfo map[string]string `json:"platform_info"`
KernelVersion string `json:"kernel_version"`
Uptime string `json:"uptime"`
LoadAverage []float64 `json:"load_average"`
CPUInfo map[string]string `json:"cpu_info"`
MemoryInfo map[string]string `json:"memory_info"`
DiskInfo []map[string]string `json:"disk_info"`
}
// Executor types
type CommandExecutor struct {
timeout time.Duration
}

527
investigation_server.go Normal file
View File

@@ -0,0 +1,527 @@
package main
import (
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
"time"
"nannyagentv2/internal/auth"
"nannyagentv2/internal/metrics"
"github.com/sashabaranov/go-openai"
)
// InvestigationRequest represents a request from Supabase to start an investigation
type InvestigationRequest struct {
InvestigationID string `json:"investigation_id"`
ApplicationGroup string `json:"application_group"`
Issue string `json:"issue"`
Context map[string]string `json:"context"`
Priority string `json:"priority"`
InitiatedBy string `json:"initiated_by"`
}
// InvestigationResponse represents the agent's response to an investigation
type InvestigationResponse struct {
AgentID string `json:"agent_id"`
InvestigationID string `json:"investigation_id"`
Status string `json:"status"`
Commands []CommandResult `json:"commands,omitempty"`
AIResponse string `json:"ai_response,omitempty"`
EpisodeID string `json:"episode_id,omitempty"`
Timestamp time.Time `json:"timestamp"`
Error string `json:"error,omitempty"`
}
// InvestigationServer handles reverse investigation requests from Supabase
type InvestigationServer struct {
agent *LinuxDiagnosticAgent // Original agent for direct user interactions
applicationAgent *LinuxDiagnosticAgent // Separate agent for application-initiated investigations
port string
agentID string
metricsCollector *metrics.Collector
authManager *auth.AuthManager
startTime time.Time
supabaseURL string
}
// NewInvestigationServer creates a new investigation server
func NewInvestigationServer(agent *LinuxDiagnosticAgent, authManager *auth.AuthManager) *InvestigationServer {
port := os.Getenv("AGENT_PORT")
if port == "" {
port = "1234"
}
// Get agent ID from authentication system
var agentID string
if authManager != nil {
if id, err := authManager.GetCurrentAgentID(); err == nil {
agentID = id
} else {
fmt.Printf("❌ Failed to get agent ID from auth manager: %v\n", err)
}
}
// Fallback to environment variable or generate one if auth fails
if agentID == "" {
agentID = os.Getenv("AGENT_ID")
if agentID == "" {
agentID = fmt.Sprintf("agent-%d", time.Now().Unix())
}
}
// Create metrics collector
metricsCollector := metrics.NewCollector("v2.0.0")
// Create a separate agent for application-initiated investigations
applicationAgent := NewLinuxDiagnosticAgent()
// Override the model to use the application-specific function
applicationAgent.model = "tensorzero::function_name::diagnose_and_heal_application"
return &InvestigationServer{
agent: agent,
applicationAgent: applicationAgent,
port: port,
agentID: agentID,
metricsCollector: metricsCollector,
authManager: authManager,
startTime: time.Now(),
supabaseURL: os.Getenv("SUPABASE_PROJECT_URL"),
}
}
// DiagnoseIssueForApplication handles diagnostic requests initiated from application/portal
func (s *InvestigationServer) DiagnoseIssueForApplication(issue, episodeID string) error {
// Set the episode ID on the application agent for continuity
s.applicationAgent.episodeID = episodeID
return s.applicationAgent.DiagnoseIssue(issue)
}
// Start starts the HTTP server and realtime polling for investigation requests
func (s *InvestigationServer) Start() error {
mux := http.NewServeMux()
// Health check endpoint
mux.HandleFunc("/health", s.handleHealth)
// Investigation endpoint
mux.HandleFunc("/investigate", s.handleInvestigation)
// Agent status endpoint
mux.HandleFunc("/status", s.handleStatus)
// Start realtime polling for backend-initiated investigations
if s.supabaseURL != "" && s.authManager != nil {
go s.startRealtimePolling()
fmt.Printf("🔄 Realtime investigation polling enabled\n")
} else {
fmt.Printf("⚠️ Realtime investigation polling disabled (missing Supabase config or auth)\n")
}
server := &http.Server{
Addr: ":" + s.port,
Handler: mux,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
}
fmt.Printf("🔍 Investigation server started on port %s (Agent ID: %s)\n", s.port, s.agentID)
return server.ListenAndServe()
}
// handleHealth responds to health check requests
func (s *InvestigationServer) handleHealth(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
response := map[string]interface{}{
"status": "healthy",
"agent_id": s.agentID,
"timestamp": time.Now(),
"version": "v2.0.0",
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// handleStatus responds with agent status and capabilities
func (s *InvestigationServer) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Collect current system metrics
systemMetrics, err := s.metricsCollector.GatherSystemMetrics()
if err != nil {
http.Error(w, fmt.Sprintf("Failed to collect metrics: %v", err), http.StatusInternalServerError)
return
}
// Convert to metrics request format for consistent data structure
metricsReq := s.metricsCollector.CreateMetricsRequest(s.agentID, systemMetrics)
response := map[string]interface{}{
"agent_id": s.agentID,
"status": "ready",
"capabilities": []string{"system_diagnostics", "ebpf_monitoring", "command_execution", "ai_analysis"},
"system_info": map[string]interface{}{
"os": fmt.Sprintf("%s %s", metricsReq.OSInfo["platform"], metricsReq.OSInfo["platform_version"]),
"kernel": metricsReq.KernelVersion,
"architecture": metricsReq.OSInfo["kernel_arch"],
"cpu_cores": metricsReq.OSInfo["cpu_cores"],
"memory": metricsReq.MemoryUsage,
"private_ips": metricsReq.IPAddress,
"load_average": fmt.Sprintf("%.2f, %.2f, %.2f",
metricsReq.LoadAverages["load1"],
metricsReq.LoadAverages["load5"],
metricsReq.LoadAverages["load15"]),
"disk_usage": fmt.Sprintf("Root: %.0fG/%.0fG (%.0f%% used)",
float64(metricsReq.FilesystemInfo[0].Used)/1024/1024/1024,
float64(metricsReq.FilesystemInfo[0].Total)/1024/1024/1024,
metricsReq.DiskUsage),
},
"uptime": time.Since(s.startTime),
"last_contact": time.Now(),
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// sendCommandResultsToTensorZero sends command results back to TensorZero and continues conversation
func (s *InvestigationServer) sendCommandResultsToTensorZero(diagnosticResp DiagnosticResponse, commandResults []CommandResult) (interface{}, error) {
// Build conversation history like in agent.go
messages := []openai.ChatCompletionMessage{
// Add the original diagnostic response as assistant message
{
Role: openai.ChatMessageRoleAssistant,
Content: fmt.Sprintf(`{"response_type":"diagnostic","reasoning":"%s","commands":%s}`,
diagnosticResp.Reasoning,
mustMarshalJSON(diagnosticResp.Commands)),
},
}
// Add command results as user message (same as agent.go does)
resultsJSON, err := json.MarshalIndent(commandResults, "", " ")
if err != nil {
return nil, fmt.Errorf("failed to marshal command results: %w", err)
}
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: string(resultsJSON),
})
// Send to TensorZero via application agent's sendRequest method
fmt.Printf("🔄 Sending command results to TensorZero for analysis...\n")
response, err := s.applicationAgent.sendRequest(messages)
if err != nil {
return nil, fmt.Errorf("failed to send request to TensorZero: %w", err)
}
if len(response.Choices) == 0 {
return nil, fmt.Errorf("no choices in TensorZero response")
}
content := response.Choices[0].Message.Content
fmt.Printf("🤖 TensorZero continued analysis:\n%s\n", content)
// Try to parse the response to determine if it's diagnostic or resolution
var diagnosticNextResp DiagnosticResponse
var resolutionResp ResolutionResponse
// Check if it's another diagnostic response
if err := json.Unmarshal([]byte(content), &diagnosticNextResp); err == nil && diagnosticNextResp.ResponseType == "diagnostic" {
fmt.Printf("🔄 TensorZero requests %d more commands\n", len(diagnosticNextResp.Commands))
return map[string]interface{}{
"type": "diagnostic",
"response": diagnosticNextResp,
"raw": content,
}, nil
}
// Check if it's a resolution response
if err := json.Unmarshal([]byte(content), &resolutionResp); err == nil && resolutionResp.ResponseType == "resolution" {
return map[string]interface{}{
"type": "resolution",
"response": resolutionResp,
"raw": content,
}, nil
}
// Return raw response if we can't parse it
return map[string]interface{}{
"type": "unknown",
"raw": content,
}, nil
}
// Helper function to marshal JSON without errors
func mustMarshalJSON(v interface{}) string {
data, _ := json.Marshal(v)
return string(data)
}
// processInvestigation handles the actual investigation using TensorZero
// This endpoint receives either:
// 1. DiagnosticResponse - Commands and eBPF programs to execute
// 2. ResolutionResponse - Final resolution (no execution needed)
func (s *InvestigationServer) handleInvestigation(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed - only POST accepted", http.StatusMethodNotAllowed)
return
}
// Parse the request body to determine what type of response this is
var requestBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
return
}
// Check the response_type field to determine how to handle this
responseType, ok := requestBody["response_type"].(string)
if !ok {
http.Error(w, "Missing or invalid response_type field", http.StatusBadRequest)
return
}
fmt.Printf("📋 Received investigation payload with response_type: %s\n", responseType)
switch responseType {
case "diagnostic":
// This is a DiagnosticResponse with commands to execute
response := s.handleDiagnosticExecution(requestBody)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
case "resolution":
// This is a ResolutionResponse - final result, just acknowledge
fmt.Printf("📋 Received final resolution from backend\n")
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"success": true,
"message": "Resolution received and acknowledged",
"agent_id": s.agentID,
})
default:
http.Error(w, fmt.Sprintf("Unknown response_type: %s", responseType), http.StatusBadRequest)
return
}
}
// handleDiagnosticExecution executes commands from a DiagnosticResponse
func (s *InvestigationServer) handleDiagnosticExecution(requestBody map[string]interface{}) map[string]interface{} {
// Parse as DiagnosticResponse
var diagnosticResp DiagnosticResponse
// Convert the map back to JSON and then parse it properly
jsonData, err := json.Marshal(requestBody)
if err != nil {
return map[string]interface{}{
"success": false,
"error": fmt.Sprintf("Failed to re-marshal request: %v", err),
"agent_id": s.agentID,
}
}
if err := json.Unmarshal(jsonData, &diagnosticResp); err != nil {
return map[string]interface{}{
"success": false,
"error": fmt.Sprintf("Failed to parse DiagnosticResponse: %v", err),
"agent_id": s.agentID,
}
}
fmt.Printf("📋 Executing %d commands from backend\n", len(diagnosticResp.Commands))
// Execute all commands
commandResults := make([]CommandResult, 0, len(diagnosticResp.Commands))
for _, cmd := range diagnosticResp.Commands {
fmt.Printf("⚙️ Executing command '%s': %s\n", cmd.ID, cmd.Command)
// Use the agent's executor to run the command
result := s.agent.executor.Execute(cmd)
commandResults = append(commandResults, result)
if result.Error != "" {
fmt.Printf("⚠️ Command '%s' had error: %s\n", cmd.ID, result.Error)
}
}
// Send command results back to TensorZero for continued analysis
fmt.Printf("🔄 Sending %d command results back to TensorZero for continued analysis\n", len(commandResults))
nextResponse, err := s.sendCommandResultsToTensorZero(diagnosticResp, commandResults)
if err != nil {
return map[string]interface{}{
"success": false,
"error": fmt.Sprintf("Failed to continue TensorZero conversation: %v", err),
"agent_id": s.agentID,
"command_results": commandResults, // Still return the results
}
}
// Return both the command results and the next response from TensorZero
return map[string]interface{}{
"success": true,
"agent_id": s.agentID,
"command_results": commandResults,
"commands_executed": len(commandResults),
"next_response": nextResponse,
"timestamp": time.Now().Format(time.RFC3339),
}
}
// PendingInvestigation represents a pending investigation from the database
type PendingInvestigation struct {
ID string `json:"id"`
InvestigationID string `json:"investigation_id"`
AgentID string `json:"agent_id"`
DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
EpisodeID *string `json:"episode_id"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
// startRealtimePolling begins polling for pending investigations
func (s *InvestigationServer) startRealtimePolling() {
fmt.Printf("🔄 Starting realtime investigation polling for agent %s\n", s.agentID)
ticker := time.NewTicker(5 * time.Second) // Poll every 5 seconds
defer ticker.Stop()
for range ticker.C {
s.checkForPendingInvestigations()
}
}
// checkForPendingInvestigations checks for new pending investigations
func (s *InvestigationServer) checkForPendingInvestigations() {
url := fmt.Sprintf("%s/rest/v1/pending_investigations?agent_id=eq.%s&status=eq.pending&order=created_at.desc",
s.supabaseURL, s.agentID)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return // Silent fail for polling
}
// Get token from auth manager
authToken, err := s.authManager.LoadToken()
if err != nil {
return // Silent fail for polling
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken.AccessToken))
req.Header.Set("Accept", "application/json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return // Silent fail for polling
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return // Silent fail for polling
}
var investigations []PendingInvestigation
err = json.NewDecoder(resp.Body).Decode(&investigations)
if err != nil {
return // Silent fail for polling
}
for _, investigation := range investigations {
fmt.Printf("🔍 Found pending investigation: %s\n", investigation.ID)
go s.handlePendingInvestigation(investigation)
}
}
// handlePendingInvestigation processes a single pending investigation
func (s *InvestigationServer) handlePendingInvestigation(investigation PendingInvestigation) {
fmt.Printf("🚀 Processing realtime investigation %s\n", investigation.InvestigationID)
// Mark as executing
err := s.updateInvestigationStatus(investigation.ID, "executing", nil, nil)
if err != nil {
fmt.Printf("❌ Failed to mark investigation as executing: %v\n", err)
return
}
// Execute diagnostic commands using existing handleDiagnosticExecution method
results := s.handleDiagnosticExecution(investigation.DiagnosticPayload)
// Mark as completed with results
err = s.updateInvestigationStatus(investigation.ID, "completed", results, nil)
if err != nil {
fmt.Printf("❌ Failed to mark investigation as completed: %v\n", err)
return
}
}
// updateInvestigationStatus updates the status of a pending investigation
func (s *InvestigationServer) updateInvestigationStatus(id, status string, results map[string]interface{}, errorMsg *string) error {
updateData := map[string]interface{}{
"status": status,
}
if status == "executing" {
updateData["started_at"] = time.Now().UTC().Format(time.RFC3339)
} else if status == "completed" {
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
if results != nil {
updateData["command_results"] = results
}
} else if status == "failed" && errorMsg != nil {
updateData["error_message"] = *errorMsg
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
}
jsonData, err := json.Marshal(updateData)
if err != nil {
return fmt.Errorf("failed to marshal update data: %v", err)
}
url := fmt.Sprintf("%s/rest/v1/pending_investigations?id=eq.%s", s.supabaseURL, id)
req, err := http.NewRequest("PATCH", url, strings.NewReader(string(jsonData)))
if err != nil {
return fmt.Errorf("failed to create request: %v", err)
}
// Get token from auth manager
authToken, err := s.authManager.LoadToken()
if err != nil {
return fmt.Errorf("failed to load auth token: %v", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken.AccessToken))
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to update investigation: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 && resp.StatusCode != 204 {
return fmt.Errorf("supabase update error: %d", resp.StatusCode)
}
return nil
}

147
main.go
View File

@@ -9,8 +9,16 @@ import (
"strconv"
"strings"
"syscall"
"time"
"nannyagentv2/internal/auth"
"nannyagentv2/internal/config"
"nannyagentv2/internal/metrics"
"nannyagentv2/internal/types"
)
const Version = "v2.0.0"
// checkRootPrivileges ensures the program is running as root
func checkRootPrivileges() {
if os.Geteuid() != 0 {
@@ -65,7 +73,7 @@ func checkKernelVersionCompatibility() {
os.Exit(1)
}
fmt.Printf("✅ Kernel version %s is compatible with eBPF\n", kernelVersion)
}
// checkEBPFSupport validates eBPF subsystem availability
@@ -89,27 +97,14 @@ func checkEBPFSupport() {
syscall.Close(int(fd))
}
fmt.Printf("✅ eBPF syscall is available\n")
}
func main() {
// runInteractiveDiagnostics starts the interactive diagnostic session
func runInteractiveDiagnostics(agent *LinuxDiagnosticAgent) {
fmt.Println("")
fmt.Println("🔍 Linux eBPF-Enhanced Diagnostic Agent")
fmt.Println("=======================================")
// Perform system compatibility checks
fmt.Println("Performing system compatibility checks...")
checkRootPrivileges()
checkKernelVersionCompatibility()
checkEBPFSupport()
fmt.Println("✅ All system checks passed")
fmt.Println("")
// Initialize the agent
agent := NewLinuxDiagnosticAgent()
// Start the interactive session
fmt.Println("Linux Diagnostic Agent Started")
fmt.Println("Enter a system issue description (or 'quit' to exit):")
@@ -129,8 +124,8 @@ func main() {
continue
}
// Process the issue with eBPF capabilities
if err := agent.DiagnoseWithEBPF(input); err != nil {
// Process the issue with AI capabilities via TensorZero
if err := agent.DiagnoseIssue(input); err != nil {
fmt.Printf("Error: %v\n", err)
}
}
@@ -141,3 +136,115 @@ func main() {
fmt.Println("Goodbye!")
}
func main() {
fmt.Printf("🚀 NannyAgent v%s starting...\n", Version)
// Perform system compatibility checks first
fmt.Println("Performing system compatibility checks...")
checkRootPrivileges()
checkKernelVersionCompatibility()
checkEBPFSupport()
fmt.Println("✅ All system checks passed")
fmt.Println("")
// Load configuration
cfg, err := config.LoadConfig()
if err != nil {
log.Fatalf("❌ Failed to load configuration: %v", err)
}
cfg.PrintConfig()
// Initialize components
authManager := auth.NewAuthManager(cfg)
metricsCollector := metrics.NewCollector(Version)
// Ensure authentication
token, err := authManager.EnsureAuthenticated()
if err != nil {
log.Fatalf("❌ Authentication failed: %v", err)
}
fmt.Println("✅ Authentication successful!")
// Initialize the diagnostic agent for interactive CLI use
agent := NewLinuxDiagnosticAgent()
// Initialize a separate agent for WebSocket investigations using the application model
applicationAgent := NewLinuxDiagnosticAgent()
applicationAgent.model = "tensorzero::function_name::diagnose_and_heal_application"
// Start WebSocket client for backend communications and investigations
wsClient := NewWebSocketClient(applicationAgent, authManager)
go func() {
if err := wsClient.Start(); err != nil {
log.Printf("❌ WebSocket client error: %v", err)
}
}()
// Start background metrics collection in a goroutine
go func() {
fmt.Println("❤️ Starting background metrics collection and heartbeat...")
ticker := time.NewTicker(time.Duration(cfg.MetricsInterval) * time.Second)
defer ticker.Stop()
// Send initial heartbeat
if err := sendHeartbeat(cfg, token, metricsCollector); err != nil {
log.Printf("⚠️ Initial heartbeat failed: %v", err)
}
// Main heartbeat loop
for range ticker.C {
// Check if token needs refresh
if authManager.IsTokenExpired(token) {
fmt.Println("🔄 Token expiring soon, refreshing...")
newToken, refreshErr := authManager.EnsureAuthenticated()
if refreshErr != nil {
log.Printf("❌ Token refresh failed: %v", refreshErr)
continue
}
token = newToken
fmt.Println("✅ Token refreshed successfully")
}
// Send heartbeat
if err := sendHeartbeat(cfg, token, metricsCollector); err != nil {
log.Printf("⚠️ Heartbeat failed: %v", err)
// If unauthorized, try to refresh token
if err.Error() == "unauthorized" {
fmt.Println("🔄 Unauthorized, attempting token refresh...")
newToken, refreshErr := authManager.EnsureAuthenticated()
if refreshErr != nil {
log.Printf("❌ Token refresh failed: %v", refreshErr)
continue
}
token = newToken
// Retry heartbeat with new token (silently)
if retryErr := sendHeartbeat(cfg, token, metricsCollector); retryErr != nil {
log.Printf("⚠️ Retry heartbeat failed: %v", retryErr)
}
}
}
// No logging for successful heartbeats - they should be silent
}
}()
// Start the interactive diagnostic session (blocking)
runInteractiveDiagnostics(agent)
}
// sendHeartbeat collects metrics and sends heartbeat to the server
func sendHeartbeat(cfg *config.Config, token *types.AuthToken, collector *metrics.Collector) error {
// Collect system metrics
systemMetrics, err := collector.GatherSystemMetrics()
if err != nil {
return fmt.Errorf("failed to gather system metrics: %w", err)
}
// Send metrics using the collector with correct agent_id from token
return collector.SendMetrics(cfg.AgentAuthURL, token.AccessToken, token.AgentID, systemMetrics)
}

839
websocket_client.go Normal file
View File

@@ -0,0 +1,839 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"os"
"os/exec"
"strings"
"time"
"nannyagentv2/internal/auth"
"nannyagentv2/internal/metrics"
"github.com/gorilla/websocket"
"github.com/sashabaranov/go-openai"
)
// Helper function for minimum of two integers
// WebSocketMessage represents a message sent over WebSocket
type WebSocketMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
}
// InvestigationTask represents a task sent to the agent
type InvestigationTask struct {
TaskID string `json:"task_id"`
InvestigationID string `json:"investigation_id"`
AgentID string `json:"agent_id"`
DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
EpisodeID string `json:"episode_id,omitempty"`
}
// TaskResult represents the result of a completed task
type TaskResult struct {
TaskID string `json:"task_id"`
Success bool `json:"success"`
CommandResults map[string]interface{} `json:"command_results,omitempty"`
Error string `json:"error,omitempty"`
}
// HeartbeatData represents heartbeat information
type HeartbeatData struct {
AgentID string `json:"agent_id"`
Timestamp time.Time `json:"timestamp"`
Version string `json:"version"`
}
// WebSocketClient handles WebSocket connection to Supabase backend
type WebSocketClient struct {
agent *LinuxDiagnosticAgent
conn *websocket.Conn
agentID string
authManager *auth.AuthManager
metricsCollector *metrics.Collector
supabaseURL string
token string
ctx context.Context
cancel context.CancelFunc
consecutiveFailures int // Track consecutive connection failures
}
// NewWebSocketClient creates a new WebSocket client
func NewWebSocketClient(agent *LinuxDiagnosticAgent, authManager *auth.AuthManager) *WebSocketClient {
// Get agent ID from authentication system
var agentID string
if authManager != nil {
if id, err := authManager.GetCurrentAgentID(); err == nil {
agentID = id
// Agent ID retrieved successfully
} else {
fmt.Printf("❌ Failed to get agent ID from auth manager: %v\n", err)
}
}
// Fallback to environment variable or generate one if auth fails
if agentID == "" {
agentID = os.Getenv("AGENT_ID")
if agentID == "" {
agentID = fmt.Sprintf("agent-%d", time.Now().Unix())
}
}
supabaseURL := os.Getenv("SUPABASE_PROJECT_URL")
if supabaseURL == "" {
log.Fatal("❌ SUPABASE_PROJECT_URL environment variable is required")
}
// Create metrics collector
metricsCollector := metrics.NewCollector("v2.0.0")
ctx, cancel := context.WithCancel(context.Background())
return &WebSocketClient{
agent: agent,
agentID: agentID,
authManager: authManager,
metricsCollector: metricsCollector,
supabaseURL: supabaseURL,
ctx: ctx,
cancel: cancel,
}
}
// Start starts the WebSocket connection and message handling
func (w *WebSocketClient) Start() error {
// Starting WebSocket client
if err := w.connect(); err != nil {
return fmt.Errorf("failed to establish WebSocket connection: %v", err)
}
// Start message reading loop
go w.handleMessages()
// Start heartbeat
go w.startHeartbeat()
// Start database polling for pending investigations
go w.pollPendingInvestigations()
// WebSocket client started
return nil
}
// Stop closes the WebSocket connection
func (c *WebSocketClient) Stop() {
c.cancel()
if c.conn != nil {
c.conn.Close()
}
}
// getAuthToken retrieves authentication token
func (c *WebSocketClient) getAuthToken() error {
if c.authManager == nil {
return fmt.Errorf("auth manager not available")
}
token, err := c.authManager.EnsureAuthenticated()
if err != nil {
return fmt.Errorf("authentication failed: %v", err)
}
c.token = token.AccessToken
return nil
}
// connect establishes WebSocket connection
func (c *WebSocketClient) connect() error {
// Get fresh auth token
if err := c.getAuthToken(); err != nil {
return fmt.Errorf("failed to get auth token: %v", err)
}
// Convert HTTP URL to WebSocket URL
wsURL := strings.Replace(c.supabaseURL, "https://", "wss://", 1)
wsURL = strings.Replace(wsURL, "http://", "ws://", 1)
wsURL += "/functions/v1/websocket-agent-handler"
// Connecting to WebSocket
// Set up headers
headers := http.Header{}
headers.Set("Authorization", "Bearer "+c.token)
// Connect
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
conn, resp, err := dialer.Dial(wsURL, headers)
if err != nil {
c.consecutiveFailures++
if c.consecutiveFailures >= 5 && resp != nil {
fmt.Printf("❌ WebSocket handshake failed with status: %d (failure #%d)\n", resp.StatusCode, c.consecutiveFailures)
}
return fmt.Errorf("websocket connection failed: %v", err)
}
c.conn = conn
// WebSocket client connected
return nil
}
// handleMessages processes incoming WebSocket messages
func (c *WebSocketClient) handleMessages() {
defer func() {
if c.conn != nil {
// Closing WebSocket connection
c.conn.Close()
}
}()
// Started WebSocket message listener
connectionStart := time.Now()
for {
select {
case <-c.ctx.Done():
// Only log context cancellation if there have been failures
if c.consecutiveFailures >= 5 {
fmt.Printf("📡 Context cancelled after %v, stopping message handler\n", time.Since(connectionStart))
}
return
default:
// Set read deadline to detect connection issues
c.conn.SetReadDeadline(time.Now().Add(90 * time.Second))
var message WebSocketMessage
readStart := time.Now()
err := c.conn.ReadJSON(&message)
readDuration := time.Since(readStart)
if err != nil {
connectionDuration := time.Since(connectionStart)
// Only log specific errors after failure threshold
if c.consecutiveFailures >= 5 {
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
log.Printf("🔒 WebSocket closed normally after %v: %v", connectionDuration, err)
} else if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("💥 ABNORMAL CLOSE after %v (code 1006 = server-side timeout/kill): %v", connectionDuration, err)
log.Printf("🕒 Last read took %v, connection lived %v", readDuration, connectionDuration)
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
log.Printf("⏰ READ TIMEOUT after %v: %v", connectionDuration, err)
} else {
log.Printf("❌ WebSocket error after %v: %v", connectionDuration, err)
}
}
// Track consecutive failures for diagnostic threshold
c.consecutiveFailures++
// Only show diagnostics after multiple failures
if c.consecutiveFailures >= 5 {
log.Printf("🔍 DIAGNOSTIC - Connection failed #%d after %v", c.consecutiveFailures, connectionDuration)
}
// Attempt reconnection instead of returning immediately
go c.attemptReconnection()
return
}
// Received WebSocket message successfully - reset failure counter
c.consecutiveFailures = 0
switch message.Type {
case "connection_ack":
// Connection acknowledged
case "heartbeat_ack":
// Heartbeat acknowledged
case "investigation_task":
// Received investigation task - processing
go c.handleInvestigationTask(message.Data)
case "task_result_ack":
// Task result acknowledged
default:
log.Printf("⚠️ Unknown message type: %s", message.Type)
}
}
}
}
// handleInvestigationTask processes investigation tasks from the backend
func (c *WebSocketClient) handleInvestigationTask(data interface{}) {
// Parse task data
taskBytes, err := json.Marshal(data)
if err != nil {
log.Printf("❌ Error marshaling task data: %v", err)
return
}
var task InvestigationTask
err = json.Unmarshal(taskBytes, &task)
if err != nil {
log.Printf("❌ Error unmarshaling investigation task: %v", err)
return
}
// Processing investigation task
// Execute diagnostic commands
results, err := c.executeDiagnosticCommands(task.DiagnosticPayload)
// Prepare task result
taskResult := TaskResult{
TaskID: task.TaskID,
Success: err == nil,
}
if err != nil {
taskResult.Error = err.Error()
fmt.Printf("❌ Task execution failed: %v\n", err)
} else {
taskResult.CommandResults = results
// Task executed successfully
}
// Send result back
c.sendTaskResult(taskResult)
}
// executeDiagnosticCommands executes the commands from a diagnostic response
func (c *WebSocketClient) executeDiagnosticCommands(diagnosticPayload map[string]interface{}) (map[string]interface{}, error) {
results := map[string]interface{}{
"agent_id": c.agentID,
"execution_time": time.Now().UTC().Format(time.RFC3339),
"command_results": []map[string]interface{}{},
}
// Extract commands from diagnostic payload
commands, ok := diagnosticPayload["commands"].([]interface{})
if !ok {
return nil, fmt.Errorf("no commands found in diagnostic payload")
}
var commandResults []map[string]interface{}
for _, cmd := range commands {
cmdMap, ok := cmd.(map[string]interface{})
if !ok {
continue
}
id, _ := cmdMap["id"].(string)
command, _ := cmdMap["command"].(string)
description, _ := cmdMap["description"].(string)
if command == "" {
continue
}
// Executing command
// Execute the command
output, exitCode, err := c.executeCommand(command)
result := map[string]interface{}{
"id": id,
"command": command,
"description": description,
"output": output,
"exit_code": exitCode,
"success": err == nil && exitCode == 0,
}
if err != nil {
result["error"] = err.Error()
fmt.Printf("❌ Command [%s] failed: %v (exit code: %d)\n", id, err, exitCode)
}
commandResults = append(commandResults, result)
}
results["command_results"] = commandResults
results["total_commands"] = len(commandResults)
results["successful_commands"] = c.countSuccessfulCommands(commandResults)
// Execute eBPF programs if present
ebpfPrograms, hasEBPF := diagnosticPayload["ebpf_programs"].([]interface{})
if hasEBPF && len(ebpfPrograms) > 0 {
ebpfResults := c.executeEBPFPrograms(ebpfPrograms)
results["ebpf_results"] = ebpfResults
results["total_ebpf_programs"] = len(ebpfPrograms)
}
return results, nil
}
// executeEBPFPrograms executes eBPF monitoring programs using the real eBPF manager
func (c *WebSocketClient) executeEBPFPrograms(ebpfPrograms []interface{}) []map[string]interface{} {
var ebpfRequests []EBPFRequest
// Convert interface{} to EBPFRequest structs
for _, prog := range ebpfPrograms {
progMap, ok := prog.(map[string]interface{})
if !ok {
continue
}
name, _ := progMap["name"].(string)
progType, _ := progMap["type"].(string)
target, _ := progMap["target"].(string)
duration, _ := progMap["duration"].(float64)
description, _ := progMap["description"].(string)
if name == "" || progType == "" || target == "" {
continue
}
ebpfRequests = append(ebpfRequests, EBPFRequest{
Name: name,
Type: progType,
Target: target,
Duration: int(duration),
Description: description,
})
}
// Execute eBPF programs using the agent's eBPF execution logic
return c.agent.executeEBPFPrograms(ebpfRequests)
}
// executeCommandsFromPayload executes commands from a payload and returns results
func (c *WebSocketClient) executeCommandsFromPayload(commands []interface{}) []map[string]interface{} {
var commandResults []map[string]interface{}
for _, cmd := range commands {
cmdMap, ok := cmd.(map[string]interface{})
if !ok {
continue
}
id, _ := cmdMap["id"].(string)
command, _ := cmdMap["command"].(string)
description, _ := cmdMap["description"].(string)
if command == "" {
continue
}
// Execute the command
output, exitCode, err := c.executeCommand(command)
result := map[string]interface{}{
"id": id,
"command": command,
"description": description,
"output": output,
"exit_code": exitCode,
"success": err == nil && exitCode == 0,
}
if err != nil {
result["error"] = err.Error()
fmt.Printf("❌ Command [%s] failed: %v (exit code: %d)\n", id, err, exitCode)
}
commandResults = append(commandResults, result)
}
return commandResults
}
// executeCommand executes a shell command and returns output, exit code, and error
func (c *WebSocketClient) executeCommand(command string) (string, int, error) {
// Parse command into parts
parts := strings.Fields(command)
if len(parts) == 0 {
return "", -1, fmt.Errorf("empty command")
}
// Create command with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, parts[0], parts[1:]...)
cmd.Env = os.Environ()
output, err := cmd.CombinedOutput()
exitCode := 0
if err != nil {
if exitError, ok := err.(*exec.ExitError); ok {
exitCode = exitError.ExitCode()
} else {
exitCode = -1
}
}
return string(output), exitCode, err
}
// countSuccessfulCommands counts the number of successful commands
func (c *WebSocketClient) countSuccessfulCommands(results []map[string]interface{}) int {
count := 0
for _, result := range results {
if success, ok := result["success"].(bool); ok && success {
count++
}
}
return count
}
// sendTaskResult sends a task result back to the backend
func (c *WebSocketClient) sendTaskResult(result TaskResult) {
message := WebSocketMessage{
Type: "task_result",
Data: result,
}
err := c.conn.WriteJSON(message)
if err != nil {
log.Printf("❌ Error sending task result: %v", err)
}
}
// startHeartbeat sends periodic heartbeat messages
func (c *WebSocketClient) startHeartbeat() {
ticker := time.NewTicker(30 * time.Second) // Heartbeat every 30 seconds
defer ticker.Stop()
// Starting heartbeat
for {
select {
case <-c.ctx.Done():
fmt.Printf("💓 Heartbeat stopped due to context cancellation\n")
return
case <-ticker.C:
// Sending heartbeat
heartbeat := WebSocketMessage{
Type: "heartbeat",
Data: HeartbeatData{
AgentID: c.agentID,
Timestamp: time.Now(),
Version: "v2.0.0",
},
}
err := c.conn.WriteJSON(heartbeat)
if err != nil {
log.Printf("❌ Error sending heartbeat: %v", err)
fmt.Printf("💓 Heartbeat failed, connection likely dead\n")
return
}
// Heartbeat sent
}
}
}
// pollPendingInvestigations polls the database for pending investigations
func (c *WebSocketClient) pollPendingInvestigations() {
// Starting database polling
ticker := time.NewTicker(5 * time.Second) // Poll every 5 seconds
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
c.checkForPendingInvestigations()
}
}
}
// checkForPendingInvestigations checks the database for new pending investigations via proxy
func (c *WebSocketClient) checkForPendingInvestigations() {
// Use Edge Function proxy instead of direct database access
url := fmt.Sprintf("%s/functions/v1/agent-database-proxy/pending-investigations", c.supabaseURL)
// Poll database for pending investigations
req, err := http.NewRequest("GET", url, nil)
if err != nil {
// Request creation failed
return
}
// Only JWT token needed for proxy - no API keys exposed
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token))
req.Header.Set("Accept", "application/json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
// Database request failed
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return
}
var investigations []PendingInvestigation
err = json.NewDecoder(resp.Body).Decode(&investigations)
if err != nil {
// Response decode failed
return
}
for _, investigation := range investigations {
go c.handlePendingInvestigation(investigation)
}
}
// handlePendingInvestigation processes a pending investigation from database polling
func (c *WebSocketClient) handlePendingInvestigation(investigation PendingInvestigation) {
// Processing pending investigation
// Mark as executing
err := c.updateInvestigationStatus(investigation.ID, "executing", nil, nil)
if err != nil {
return
}
// Execute diagnostic commands
results, err := c.executeDiagnosticCommands(investigation.DiagnosticPayload)
// Prepare the base results map we'll send to DB
resultsForDB := map[string]interface{}{
"agent_id": c.agentID,
"execution_time": time.Now().UTC().Format(time.RFC3339),
"command_results": results,
}
// If command execution failed, mark investigation as failed
if err != nil {
errorMsg := err.Error()
// Include partial results when possible
if results != nil {
resultsForDB["command_results"] = results
}
c.updateInvestigationStatus(investigation.ID, "failed", resultsForDB, &errorMsg)
// Investigation failed
return
}
// Try to continue the TensorZero conversation by sending command results back
// Build messages: assistant = diagnostic payload, user = command results
diagJSON, _ := json.Marshal(investigation.DiagnosticPayload)
commandsJSON, _ := json.MarshalIndent(results, "", " ")
messages := []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleAssistant,
Content: string(diagJSON),
},
{
Role: openai.ChatMessageRoleUser,
Content: string(commandsJSON),
},
}
// Use the episode ID from the investigation to maintain conversation continuity
episodeID := ""
if investigation.EpisodeID != nil {
episodeID = *investigation.EpisodeID
}
// Continue conversation until resolution (same as agent)
var finalAIContent string
for {
tzResp, tzErr := c.agent.sendRequestWithEpisode(messages, episodeID)
if tzErr != nil {
fmt.Printf("⚠️ TensorZero continuation failed: %v\n", tzErr)
// Fall back to marking completed with command results only
c.updateInvestigationStatus(investigation.ID, "completed", resultsForDB, nil)
return
}
if len(tzResp.Choices) == 0 {
fmt.Printf("⚠️ No choices in TensorZero response\n")
c.updateInvestigationStatus(investigation.ID, "completed", resultsForDB, nil)
return
}
aiContent := tzResp.Choices[0].Message.Content
if len(aiContent) > 300 {
// AI response received successfully
} else {
fmt.Printf("🤖 AI Response: %s\n", aiContent)
}
// Check if this is a resolution response (final)
var resolutionResp struct {
ResponseType string `json:"response_type"`
RootCause string `json:"root_cause"`
ResolutionPlan string `json:"resolution_plan"`
Confidence string `json:"confidence"`
}
fmt.Printf("🔍 Analyzing AI response type...\n")
if err := json.Unmarshal([]byte(aiContent), &resolutionResp); err == nil && resolutionResp.ResponseType == "resolution" {
// This is the final resolution - show summary and complete
fmt.Printf("\n=== DIAGNOSIS COMPLETE ===\n")
fmt.Printf("Root Cause: %s\n", resolutionResp.RootCause)
fmt.Printf("Resolution Plan: %s\n", resolutionResp.ResolutionPlan)
fmt.Printf("Confidence: %s\n", resolutionResp.Confidence)
finalAIContent = aiContent
break
}
// Check if this is another diagnostic response requiring more commands
var diagnosticResp struct {
ResponseType string `json:"response_type"`
Commands []interface{} `json:"commands"`
EBPFPrograms []interface{} `json:"ebpf_programs"`
}
if err := json.Unmarshal([]byte(aiContent), &diagnosticResp); err == nil && diagnosticResp.ResponseType == "diagnostic" {
fmt.Printf("🔄 AI requested additional diagnostics, executing...\n")
// Execute additional commands if any
additionalResults := map[string]interface{}{
"command_results": []map[string]interface{}{},
}
if len(diagnosticResp.Commands) > 0 {
fmt.Printf("🔧 Executing %d additional diagnostic commands...\n", len(diagnosticResp.Commands))
commandResults := c.executeCommandsFromPayload(diagnosticResp.Commands)
additionalResults["command_results"] = commandResults
}
// Execute additional eBPF programs if any
if len(diagnosticResp.EBPFPrograms) > 0 {
ebpfResults := c.executeEBPFPrograms(diagnosticResp.EBPFPrograms)
additionalResults["ebpf_results"] = ebpfResults
}
// Add AI response and additional results to conversation
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: aiContent,
})
additionalResultsJSON, _ := json.MarshalIndent(additionalResults, "", " ")
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: string(additionalResultsJSON),
})
continue
}
// If neither resolution nor diagnostic, treat as final response
fmt.Printf("⚠️ Unknown response type - treating as final response\n")
finalAIContent = aiContent
break
}
// Attach final AI response to results for DB and mark as completed_with_analysis
resultsForDB["ai_response"] = finalAIContent
c.updateInvestigationStatus(investigation.ID, "completed_with_analysis", resultsForDB, nil)
}
// updateInvestigationStatus updates the status of a pending investigation
func (c *WebSocketClient) updateInvestigationStatus(id, status string, results map[string]interface{}, errorMsg *string) error {
updateData := map[string]interface{}{
"status": status,
}
if status == "executing" {
updateData["started_at"] = time.Now().UTC().Format(time.RFC3339)
} else if status == "completed" {
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
if results != nil {
updateData["command_results"] = results
}
} else if status == "failed" && errorMsg != nil {
updateData["error_message"] = *errorMsg
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
}
jsonData, err := json.Marshal(updateData)
if err != nil {
return fmt.Errorf("failed to marshal update data: %v", err)
}
url := fmt.Sprintf("%s/functions/v1/agent-database-proxy/pending-investigations/%s", c.supabaseURL, id)
req, err := http.NewRequest("PATCH", url, strings.NewReader(string(jsonData)))
if err != nil {
return fmt.Errorf("failed to create request: %v", err)
}
// Only JWT token needed for proxy - no API keys exposed
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token))
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to update investigation: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 && resp.StatusCode != 204 {
return fmt.Errorf("supabase update error: %d", resp.StatusCode)
}
return nil
}
// attemptReconnection attempts to reconnect the WebSocket with backoff
func (c *WebSocketClient) attemptReconnection() {
backoffDurations := []time.Duration{
2 * time.Second,
5 * time.Second,
10 * time.Second,
20 * time.Second,
30 * time.Second,
}
for i, backoff := range backoffDurations {
select {
case <-c.ctx.Done():
return
default:
c.consecutiveFailures++
// Only show messages after 5 consecutive failures
if c.consecutiveFailures >= 5 {
log.Printf("🔄 Attempting WebSocket reconnection (attempt %d/%d) - %d consecutive failures", i+1, len(backoffDurations), c.consecutiveFailures)
}
time.Sleep(backoff)
if err := c.connect(); err != nil {
if c.consecutiveFailures >= 5 {
log.Printf("❌ Reconnection attempt %d failed: %v", i+1, err)
}
continue
}
// Successfully reconnected - reset failure counter
if c.consecutiveFailures >= 5 {
log.Printf("✅ WebSocket reconnected successfully after %d failures", c.consecutiveFailures)
}
c.consecutiveFailures = 0
go c.handleMessages() // Restart message handling
return
}
}
log.Printf("❌ Failed to reconnect after %d attempts, giving up", len(backoffDurations))
}