Compare commits
4 Commits
f69e1dbc66
...
8328f8d5b3
| Author | SHA256 | Date | |
|---|---|---|---|
|
|
8328f8d5b3 | ||
|
|
8832450a1f | ||
|
|
0a8b2dc202 | ||
|
|
6fd403cb5f |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -23,6 +23,7 @@ go.work
|
||||
go.work.sum
|
||||
|
||||
# env file
|
||||
.env
|
||||
.env*
|
||||
nannyagent*
|
||||
nanny-agent*
|
||||
nanny-agent*
|
||||
.vscode
|
||||
|
||||
256
agent.go
256
agent.go
@@ -55,10 +55,11 @@ type LinuxDiagnosticAgent struct {
|
||||
|
||||
// NewLinuxDiagnosticAgent creates a new diagnostic agent
|
||||
func NewLinuxDiagnosticAgent() *LinuxDiagnosticAgent {
|
||||
endpoint := os.Getenv("NANNYAPI_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
// Default endpoint - OpenAI SDK will append /chat/completions automatically
|
||||
endpoint = "http://tensorzero.netcup.internal:3000/openai/v1"
|
||||
// Get Supabase project URL for TensorZero proxy
|
||||
supabaseURL := os.Getenv("SUPABASE_PROJECT_URL")
|
||||
if supabaseURL == "" {
|
||||
fmt.Printf("Warning: SUPABASE_PROJECT_URL not set, TensorZero integration will not work\n")
|
||||
supabaseURL = "https://gpqzsricripnvbrpsyws.supabase.co" // fallback
|
||||
}
|
||||
|
||||
model := os.Getenv("NANNYAPI_MODEL")
|
||||
@@ -67,14 +68,9 @@ func NewLinuxDiagnosticAgent() *LinuxDiagnosticAgent {
|
||||
fmt.Printf("Warning: Using default model '%s'. Set NANNYAPI_MODEL environment variable for your specific function.\n", model)
|
||||
}
|
||||
|
||||
// Create OpenAI client with custom base URL
|
||||
// Note: The OpenAI SDK automatically appends "/chat/completions" to the base URL
|
||||
config := openai.DefaultConfig("")
|
||||
config.BaseURL = endpoint
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
// Note: We don't use the OpenAI client anymore, we use direct HTTP to Supabase proxy
|
||||
agent := &LinuxDiagnosticAgent{
|
||||
client: client,
|
||||
client: nil, // Not used anymore
|
||||
model: model,
|
||||
executor: NewCommandExecutor(10 * time.Second), // 10 second timeout for commands
|
||||
}
|
||||
@@ -106,7 +102,7 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error {
|
||||
|
||||
for {
|
||||
// Send request to TensorZero API via OpenAI SDK
|
||||
response, err := a.sendRequest(messages)
|
||||
response, err := a.sendRequestWithEpisode(messages, a.episodeID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
@@ -119,34 +115,60 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error {
|
||||
fmt.Printf("\nAI Response:\n%s\n", content)
|
||||
|
||||
// Parse the response to determine next action
|
||||
var diagnosticResp DiagnosticResponse
|
||||
var diagnosticResp EBPFEnhancedDiagnosticResponse
|
||||
var resolutionResp ResolutionResponse
|
||||
|
||||
// Try to parse as diagnostic response first
|
||||
// Try to parse as diagnostic response first (with eBPF support)
|
||||
if err := json.Unmarshal([]byte(content), &diagnosticResp); err == nil && diagnosticResp.ResponseType == "diagnostic" {
|
||||
// Handle diagnostic phase
|
||||
fmt.Printf("\nReasoning: %s\n", diagnosticResp.Reasoning)
|
||||
|
||||
if len(diagnosticResp.Commands) == 0 {
|
||||
fmt.Println("No commands to execute in diagnostic phase")
|
||||
break
|
||||
}
|
||||
|
||||
// Execute commands and collect results
|
||||
commandResults := make([]CommandResult, 0, len(diagnosticResp.Commands))
|
||||
for _, cmd := range diagnosticResp.Commands {
|
||||
fmt.Printf("\nExecuting command '%s': %s\n", cmd.ID, cmd.Command)
|
||||
result := a.executor.Execute(cmd)
|
||||
commandResults = append(commandResults, result)
|
||||
if len(diagnosticResp.Commands) > 0 {
|
||||
fmt.Printf("🔧 Executing diagnostic commands...\n")
|
||||
for _, cmd := range diagnosticResp.Commands {
|
||||
result := a.executor.Execute(cmd)
|
||||
commandResults = append(commandResults, result)
|
||||
|
||||
fmt.Printf("Output:\n%s\n", result.Output)
|
||||
if result.Error != "" {
|
||||
fmt.Printf("Error: %s\n", result.Error)
|
||||
if result.ExitCode != 0 {
|
||||
fmt.Printf("❌ Command '%s' failed with exit code %d\n", cmd.ID, result.ExitCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare command results as user message
|
||||
resultsJSON, err := json.MarshalIndent(commandResults, "", " ")
|
||||
// Execute eBPF programs if present
|
||||
var ebpfResults []map[string]interface{}
|
||||
if len(diagnosticResp.EBPFPrograms) > 0 {
|
||||
ebpfResults = a.executeEBPFPrograms(diagnosticResp.EBPFPrograms)
|
||||
}
|
||||
|
||||
// Prepare combined results as user message
|
||||
allResults := map[string]interface{}{
|
||||
"command_results": commandResults,
|
||||
"executed_commands": len(commandResults),
|
||||
}
|
||||
|
||||
// Include eBPF results if any were executed
|
||||
if len(ebpfResults) > 0 {
|
||||
allResults["ebpf_results"] = ebpfResults
|
||||
allResults["executed_ebpf_programs"] = len(ebpfResults)
|
||||
|
||||
// Extract evidence summary for TensorZero
|
||||
evidenceSummary := make([]string, 0)
|
||||
for _, result := range ebpfResults {
|
||||
name := result["name"]
|
||||
eventCount := result["data_points"]
|
||||
description := result["description"]
|
||||
status := result["status"]
|
||||
|
||||
summaryStr := fmt.Sprintf("%s: %v events (%s) - %s", name, eventCount, status, description)
|
||||
evidenceSummary = append(evidenceSummary, summaryStr)
|
||||
}
|
||||
allResults["ebpf_evidence_summary"] = evidenceSummary
|
||||
}
|
||||
|
||||
resultsJSON, err := json.MarshalIndent(allResults, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal command results: %w", err)
|
||||
}
|
||||
@@ -182,6 +204,119 @@ func (a *LinuxDiagnosticAgent) DiagnoseIssue(issue string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeEBPFPrograms executes REAL eBPF monitoring programs using the actual eBPF manager
|
||||
func (a *LinuxDiagnosticAgent) executeEBPFPrograms(ebpfPrograms []EBPFRequest) []map[string]interface{} {
|
||||
var results []map[string]interface{}
|
||||
|
||||
if a.ebpfManager == nil {
|
||||
fmt.Printf("❌ eBPF manager not initialized\n")
|
||||
return results
|
||||
}
|
||||
|
||||
for _, prog := range ebpfPrograms {
|
||||
// eBPF program starting - only show in debug mode
|
||||
|
||||
// Actually start the eBPF program using the real manager
|
||||
programID, err := a.ebpfManager.StartEBPFProgram(prog)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ Failed to start eBPF program [%s]: %v\n", prog.Name, err)
|
||||
result := map[string]interface{}{
|
||||
"name": prog.Name,
|
||||
"type": prog.Type,
|
||||
"target": prog.Target,
|
||||
"duration": int(prog.Duration),
|
||||
"description": prog.Description,
|
||||
"status": "failed",
|
||||
"error": err.Error(),
|
||||
"success": false,
|
||||
}
|
||||
results = append(results, result)
|
||||
continue
|
||||
}
|
||||
|
||||
// Let the eBPF program run for the specified duration
|
||||
time.Sleep(time.Duration(prog.Duration) * time.Second)
|
||||
|
||||
// Give the collectEvents goroutine a moment to finish and store results
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Use a channel to implement timeout for GetProgramResults
|
||||
type resultPair struct {
|
||||
trace *EBPFTrace
|
||||
err error
|
||||
}
|
||||
resultChan := make(chan resultPair, 1)
|
||||
|
||||
go func() {
|
||||
trace, err := a.ebpfManager.GetProgramResults(programID)
|
||||
resultChan <- resultPair{trace, err}
|
||||
}()
|
||||
|
||||
var trace *EBPFTrace
|
||||
var resultErr error
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
trace = result.trace
|
||||
resultErr = result.err
|
||||
case <-time.After(3 * time.Second):
|
||||
resultErr = fmt.Errorf("timeout getting results after 3 seconds")
|
||||
}
|
||||
|
||||
// Try to stop the program (may already be stopped by collectEvents)
|
||||
stopErr := a.ebpfManager.StopProgram(programID)
|
||||
if stopErr != nil {
|
||||
// Only show warning in debug mode - this is normal for completed programs
|
||||
}
|
||||
|
||||
if resultErr != nil {
|
||||
fmt.Printf("❌ Failed to get results for eBPF program [%s]: %v\n", prog.Name, resultErr)
|
||||
result := map[string]interface{}{
|
||||
"name": prog.Name,
|
||||
"type": prog.Type,
|
||||
"target": prog.Target,
|
||||
"duration": int(prog.Duration),
|
||||
"description": prog.Description,
|
||||
"status": "collection_failed",
|
||||
"error": resultErr.Error(),
|
||||
"success": false,
|
||||
}
|
||||
results = append(results, result)
|
||||
continue
|
||||
} // Process the real eBPF trace data
|
||||
result := map[string]interface{}{
|
||||
"name": prog.Name,
|
||||
"type": prog.Type,
|
||||
"target": prog.Target,
|
||||
"duration": int(prog.Duration),
|
||||
"description": prog.Description,
|
||||
"status": "completed",
|
||||
"success": true,
|
||||
}
|
||||
|
||||
// Extract real data from the trace
|
||||
if trace != nil {
|
||||
result["trace_id"] = trace.TraceID
|
||||
result["data_points"] = trace.EventCount
|
||||
result["events"] = trace.Events
|
||||
result["summary"] = trace.Summary
|
||||
result["process_list"] = trace.ProcessList
|
||||
result["start_time"] = trace.StartTime.Format(time.RFC3339)
|
||||
result["end_time"] = trace.EndTime.Format(time.RFC3339)
|
||||
result["actual_duration"] = trace.EndTime.Sub(trace.StartTime).Seconds()
|
||||
|
||||
} else {
|
||||
result["data_points"] = 0
|
||||
result["error"] = "No trace data returned"
|
||||
fmt.Printf("⚠️ eBPF program [%s] completed but returned no trace data\n", prog.Name)
|
||||
}
|
||||
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// TensorZeroRequest represents a request structure compatible with TensorZero's episode_id
|
||||
type TensorZeroRequest struct {
|
||||
Model string `json:"model"`
|
||||
@@ -195,8 +330,13 @@ type TensorZeroResponse struct {
|
||||
EpisodeID string `json:"episode_id"`
|
||||
}
|
||||
|
||||
// sendRequest sends a request to the TensorZero API with tensorzero::episode_id support
|
||||
// sendRequest sends a request to the TensorZero API via Supabase proxy with JWT authentication
|
||||
func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessage) (*openai.ChatCompletionResponse, error) {
|
||||
return a.sendRequestWithEpisode(messages, "")
|
||||
}
|
||||
|
||||
// sendRequestWithEpisode sends a request with a specific episode ID
|
||||
func (a *LinuxDiagnosticAgent) sendRequestWithEpisode(messages []openai.ChatCompletionMessage, episodeID string) (*openai.ChatCompletionResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -206,9 +346,12 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
// Include tensorzero::episode_id for conversation continuity (if we have one)
|
||||
// Include tensorzero::episode_id for conversation continuity
|
||||
// Use agent's existing episode ID if available, otherwise use provided one
|
||||
if a.episodeID != "" {
|
||||
tzRequest.EpisodeID = a.episodeID
|
||||
} else if episodeID != "" {
|
||||
tzRequest.EpisodeID = episodeID
|
||||
}
|
||||
|
||||
fmt.Printf("Debug: Sending request to model: %s", a.model)
|
||||
@@ -223,17 +366,14 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
endpoint := os.Getenv("NANNYAPI_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
endpoint = "http://tensorzero.netcup.internal:3000/openai/v1"
|
||||
// Get Supabase project URL and build TensorZero proxy endpoint
|
||||
supabaseURL := os.Getenv("SUPABASE_PROJECT_URL")
|
||||
if supabaseURL == "" {
|
||||
supabaseURL = "https://gpqzsricripnvbrpsyws.supabase.co"
|
||||
}
|
||||
|
||||
// Ensure the endpoint ends with /chat/completions
|
||||
if endpoint[len(endpoint)-1] != '/' {
|
||||
endpoint += "/"
|
||||
}
|
||||
endpoint += "chat/completions"
|
||||
// Build Supabase function URL with OpenAI v1 compatible path
|
||||
endpoint := supabaseURL + "/functions/v1/tensorzero-proxy/openai/v1/chat/completions"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(requestBody))
|
||||
if err != nil {
|
||||
@@ -242,6 +382,14 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Add JWT authentication header
|
||||
accessToken, err := a.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
// Make the request
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
@@ -257,7 +405,7 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
return nil, fmt.Errorf("TensorZero API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse TensorZero response
|
||||
@@ -274,3 +422,31 @@ func (a *LinuxDiagnosticAgent) sendRequest(messages []openai.ChatCompletionMessa
|
||||
|
||||
return &tzResponse.ChatCompletionResponse, nil
|
||||
}
|
||||
|
||||
// getAccessToken retrieves the current access token for authentication
|
||||
func (a *LinuxDiagnosticAgent) getAccessToken() (string, error) {
|
||||
// Read token from the standard token file location
|
||||
tokenPath := os.Getenv("TOKEN_PATH")
|
||||
if tokenPath == "" {
|
||||
tokenPath = "/var/lib/nannyagent/token.json"
|
||||
}
|
||||
|
||||
tokenData, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
var tokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(tokenData, &tokenInfo); err != nil {
|
||||
return "", fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
if tokenInfo.AccessToken == "" {
|
||||
return "", fmt.Errorf("access token is empty")
|
||||
}
|
||||
|
||||
return tokenInfo.AccessToken, nil
|
||||
}
|
||||
|
||||
14
go.mod
14
go.mod
@@ -6,7 +6,19 @@ toolchain go1.24.2
|
||||
|
||||
require (
|
||||
github.com/cilium/ebpf v0.19.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/sashabaranov/go-openai v1.32.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.5
|
||||
)
|
||||
|
||||
require golang.org/x/sys v0.31.0 // indirect
|
||||
require (
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
)
|
||||
|
||||
36
go.sum
36
go.sum
@@ -1,9 +1,18 @@
|
||||
github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
|
||||
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
|
||||
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
|
||||
@@ -12,17 +21,44 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/sashabaranov/go-openai v1.32.0 h1:Yk3iE9moX3RBXxrof3OBtUBrE7qZR0zF9ebsoO4zVzI=
|
||||
github.com/sashabaranov/go-openai v1.32.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
|
||||
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
|
||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
||||
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
509
internal/auth/auth.go
Normal file
509
internal/auth/auth.go
Normal file
@@ -0,0 +1,509 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nannyagentv2/internal/config"
|
||||
"nannyagentv2/internal/types"
|
||||
)
|
||||
|
||||
const (
|
||||
// Token storage location (secure directory)
|
||||
TokenStorageDir = "/var/lib/nannyagent"
|
||||
TokenStorageFile = ".agent_token.json"
|
||||
RefreshTokenFile = ".refresh_token"
|
||||
|
||||
// Polling configuration
|
||||
MaxPollAttempts = 60 // 5 minutes (60 * 5 seconds)
|
||||
PollInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
// AuthManager handles all authentication-related operations
|
||||
type AuthManager struct {
|
||||
config *config.Config
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewAuthManager creates a new authentication manager
|
||||
func NewAuthManager(cfg *config.Config) *AuthManager {
|
||||
return &AuthManager{
|
||||
config: cfg,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureTokenStorageDir creates the token storage directory if it doesn't exist
|
||||
func (am *AuthManager) EnsureTokenStorageDir() error {
|
||||
// Check if running as root
|
||||
if os.Geteuid() != 0 {
|
||||
return fmt.Errorf("must run as root to create secure token storage directory")
|
||||
}
|
||||
|
||||
// Create directory with restricted permissions (0700 - only root can access)
|
||||
if err := os.MkdirAll(TokenStorageDir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create token storage directory: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartDeviceAuthorization initiates the OAuth device authorization flow
|
||||
func (am *AuthManager) StartDeviceAuthorization() (*types.DeviceAuthResponse, error) {
|
||||
payload := map[string]interface{}{
|
||||
"client_id": "nannyagent-cli",
|
||||
"scope": []string{"agent:register"},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/device/authorize", am.config.DeviceAuthURL)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := am.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start device authorization: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("device authorization failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var deviceResp types.DeviceAuthResponse
|
||||
if err := json.Unmarshal(body, &deviceResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &deviceResp, nil
|
||||
}
|
||||
|
||||
// PollForToken polls the token endpoint until authorization is complete
|
||||
func (am *AuthManager) PollForToken(deviceCode string) (*types.TokenResponse, error) {
|
||||
fmt.Println("⏳ Waiting for user authorization...")
|
||||
|
||||
for attempts := 0; attempts < MaxPollAttempts; attempts++ {
|
||||
tokenReq := types.TokenRequest{
|
||||
GrantType: "urn:ietf:params:oauth:grant-type:device_code",
|
||||
DeviceCode: deviceCode,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal token request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/token", am.config.DeviceAuthURL)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := am.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to poll for token: %w", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token response: %w", err)
|
||||
}
|
||||
|
||||
var tokenResp types.TokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.Error != "" {
|
||||
if tokenResp.Error == "authorization_pending" {
|
||||
fmt.Print(".")
|
||||
time.Sleep(PollInterval)
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("authorization failed: %s", tokenResp.ErrorDescription)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken != "" {
|
||||
fmt.Println("\n✅ Authorization successful!")
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
time.Sleep(PollInterval)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("authorization timed out after %d attempts", MaxPollAttempts)
|
||||
}
|
||||
|
||||
// RefreshAccessToken refreshes an expired access token using the refresh token
|
||||
func (am *AuthManager) RefreshAccessToken(refreshToken string) (*types.TokenResponse, error) {
|
||||
tokenReq := types.TokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/token", am.config.DeviceAuthURL)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := am.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp types.TokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.Error != "" {
|
||||
return nil, fmt.Errorf("token refresh failed: %s", tokenResp.ErrorDescription)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// SaveToken saves the authentication token to secure local storage
|
||||
func (am *AuthManager) SaveToken(token *types.AuthToken) error {
|
||||
if err := am.EnsureTokenStorageDir(); err != nil {
|
||||
return fmt.Errorf("failed to ensure token storage directory: %w", err)
|
||||
}
|
||||
|
||||
// Save main token file
|
||||
tokenPath := am.getTokenPath()
|
||||
jsonData, err := json.MarshalIndent(token, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(tokenPath, jsonData, 0600); err != nil {
|
||||
return fmt.Errorf("failed to save token: %w", err)
|
||||
}
|
||||
|
||||
// Also save refresh token separately for backup recovery
|
||||
if token.RefreshToken != "" {
|
||||
refreshTokenPath := filepath.Join(TokenStorageDir, RefreshTokenFile)
|
||||
if err := os.WriteFile(refreshTokenPath, []byte(token.RefreshToken), 0600); err != nil {
|
||||
// Don't fail if refresh token backup fails, just log
|
||||
fmt.Printf("Warning: Failed to save backup refresh token: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
} // LoadToken loads the authentication token from secure local storage
|
||||
func (am *AuthManager) LoadToken() (*types.AuthToken, error) {
|
||||
tokenPath := am.getTokenPath()
|
||||
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
var token types.AuthToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if time.Now().After(token.ExpiresAt.Add(-5 * time.Minute)) {
|
||||
return nil, fmt.Errorf("token is expired or expiring soon")
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// IsTokenExpired checks if a token needs refresh
|
||||
func (am *AuthManager) IsTokenExpired(token *types.AuthToken) bool {
|
||||
// Consider token expired if it expires within the next 5 minutes
|
||||
return time.Now().After(token.ExpiresAt.Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
// RegisterDevice performs the complete device registration flow
|
||||
func (am *AuthManager) RegisterDevice() (*types.AuthToken, error) {
|
||||
// Step 1: Start device authorization
|
||||
deviceAuth, err := am.StartDeviceAuthorization()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start device authorization: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Please visit: %s\n", deviceAuth.VerificationURI)
|
||||
fmt.Printf("And enter code: %s\n", deviceAuth.UserCode)
|
||||
|
||||
// Step 2: Poll for token
|
||||
tokenResp, err := am.PollForToken(deviceAuth.DeviceCode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Create token storage
|
||||
token := &types.AuthToken{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
|
||||
AgentID: tokenResp.AgentID,
|
||||
}
|
||||
|
||||
// Step 4: Save token
|
||||
if err := am.SaveToken(token); err != nil {
|
||||
return nil, fmt.Errorf("failed to save token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// EnsureAuthenticated ensures the agent has a valid token, trying in order:
//  1. the stored token, if it is not yet near expiry;
//  2. a token refresh, using any refresh token recoverable from the token
//     file (even when expired) or from the backup refresh-token file;
//  3. a brand-new interactive device registration.
func (am *AuthManager) EnsureAuthenticated() (*types.AuthToken, error) {
	// Try to load existing token
	token, err := am.LoadToken()
	if err == nil && !am.IsTokenExpired(token) {
		return token, nil
	}

	// Try to refresh with existing refresh token (even if access token is missing/expired)
	var refreshToken string
	if err == nil && token.RefreshToken != "" {
		// Use refresh token from loaded token
		refreshToken = token.RefreshToken
	} else {
		// Try to load refresh token from main token file even if load failed
		// (LoadToken rejects expired tokens, but the file may still hold a
		// usable refresh token).
		if existingToken, loadErr := am.loadTokenIgnoringExpiry(); loadErr == nil && existingToken.RefreshToken != "" {
			refreshToken = existingToken.RefreshToken
		} else {
			// Try to load refresh token from backup file
			if backupRefreshToken, backupErr := am.loadRefreshTokenFromBackup(); backupErr == nil {
				refreshToken = backupRefreshToken
				fmt.Println("🔄 Found backup refresh token, attempting to use it...")
			}
		}
	}

	if refreshToken != "" {
		fmt.Println("🔄 Attempting to refresh access token...")

		refreshResp, refreshErr := am.RefreshAccessToken(refreshToken)
		if refreshErr == nil {
			// Get existing agent_id from current token or backup, since the
			// refresh response is not guaranteed to repeat it.
			var agentID string
			if err == nil && token.AgentID != "" {
				agentID = token.AgentID
			} else if existingToken, loadErr := am.loadTokenIgnoringExpiry(); loadErr == nil {
				agentID = existingToken.AgentID
			}

			// Create new token with refreshed values
			newToken := &types.AuthToken{
				AccessToken:  refreshResp.AccessToken,
				RefreshToken: refreshToken, // Keep existing refresh token
				TokenType:    refreshResp.TokenType,
				ExpiresAt:    time.Now().Add(time.Duration(refreshResp.ExpiresIn) * time.Second),
				AgentID:      agentID, // Preserve agent_id
			}

			// Update refresh token if a new one was provided
			if refreshResp.RefreshToken != "" {
				newToken.RefreshToken = refreshResp.RefreshToken
			}

			// A failed save falls through to full re-registration below.
			if saveErr := am.SaveToken(newToken); saveErr == nil {
				return newToken, nil
			}
		} else {
			fmt.Printf("⚠️ Token refresh failed: %v\n", refreshErr)
		}
	}

	// Last resort: interactive device registration.
	fmt.Println("📝 Initiating new device registration...")
	return am.RegisterDevice()
}
|
||||
|
||||
// loadTokenIgnoringExpiry loads token file without checking expiry
|
||||
func (am *AuthManager) loadTokenIgnoringExpiry() (*types.AuthToken, error) {
|
||||
tokenPath := am.getTokenPath()
|
||||
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
var token types.AuthToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// loadRefreshTokenFromBackup tries to load refresh token from backup file
|
||||
func (am *AuthManager) loadRefreshTokenFromBackup() (string, error) {
|
||||
refreshTokenPath := filepath.Join(TokenStorageDir, RefreshTokenFile)
|
||||
|
||||
data, err := os.ReadFile(refreshTokenPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read refresh token backup: %w", err)
|
||||
}
|
||||
|
||||
refreshToken := strings.TrimSpace(string(data))
|
||||
if refreshToken == "" {
|
||||
return "", fmt.Errorf("refresh token backup is empty")
|
||||
}
|
||||
|
||||
return refreshToken, nil
|
||||
}
|
||||
|
||||
// GetCurrentAgentID retrieves the agent ID, preferring the local cache file
// and falling back to decoding the 'sub' claim of the JWT access token
// (caching the result for future calls).
func (am *AuthManager) GetCurrentAgentID() (string, error) {
	// First try to read from local cache
	agentID, err := am.loadCachedAgentID()
	if err == nil && agentID != "" {
		return agentID, nil
	}

	// Cache miss - extract from JWT token and cache it
	token, err := am.LoadToken()
	if err != nil {
		return "", fmt.Errorf("failed to load token: %w", err)
	}

	// Extract agent ID from JWT 'sub' field
	agentID, err = am.extractAgentIDFromJWT(token.AccessToken)
	if err != nil {
		return "", fmt.Errorf("failed to extract agent ID from JWT: %w", err)
	}

	// Cache the agent ID for future use
	if err := am.cacheAgentID(agentID); err != nil {
		// Log warning but don't fail - we still have the agent ID
		fmt.Printf("Warning: Failed to cache agent ID: %v\n", err)
	}

	return agentID, nil
}
|
||||
|
||||
// extractAgentIDFromJWT decodes the JWT token and extracts the agent ID from 'sub' field
|
||||
func (am *AuthManager) extractAgentIDFromJWT(tokenString string) (string, error) {
|
||||
// Basic JWT decoding without verification (since we trust Supabase)
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", fmt.Errorf("invalid JWT token format")
|
||||
}
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload := parts[1]
|
||||
|
||||
// Add padding if needed for base64 decoding
|
||||
for len(payload)%4 != 0 {
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
|
||||
// Parse JSON payload
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return "", fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
// The agent ID is in the 'sub' field (subject)
|
||||
if agentID, ok := claims["sub"].(string); ok && agentID != "" {
|
||||
return agentID, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("agent ID (sub) not found in JWT claims")
|
||||
}
|
||||
|
||||
// loadCachedAgentID reads the cached agent ID from local storage
|
||||
func (am *AuthManager) loadCachedAgentID() (string, error) {
|
||||
agentIDPath := filepath.Join(TokenStorageDir, "agent_id")
|
||||
|
||||
data, err := os.ReadFile(agentIDPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read cached agent ID: %w", err)
|
||||
}
|
||||
|
||||
agentID := strings.TrimSpace(string(data))
|
||||
if agentID == "" {
|
||||
return "", fmt.Errorf("cached agent ID is empty")
|
||||
}
|
||||
|
||||
return agentID, nil
|
||||
}
|
||||
|
||||
// cacheAgentID stores the agent ID in local cache
|
||||
func (am *AuthManager) cacheAgentID(agentID string) error {
|
||||
// Ensure the directory exists
|
||||
if err := am.EnsureTokenStorageDir(); err != nil {
|
||||
return fmt.Errorf("failed to ensure storage directory: %w", err)
|
||||
}
|
||||
|
||||
agentIDPath := filepath.Join(TokenStorageDir, "agent_id")
|
||||
|
||||
// Write agent ID to file with secure permissions
|
||||
if err := os.WriteFile(agentIDPath, []byte(agentID), 0600); err != nil {
|
||||
return fmt.Errorf("failed to write agent ID cache: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *AuthManager) getTokenPath() string {
|
||||
if am.config.TokenPath != "" {
|
||||
return am.config.TokenPath
|
||||
}
|
||||
return filepath.Join(TokenStorageDir, TokenStorageFile)
|
||||
}
|
||||
|
||||
// getHostname returns the OS-reported hostname, or "unknown" when it
// cannot be determined.
func getHostname() string {
	hostname, err := os.Hostname()
	if err != nil {
		return "unknown"
	}
	return hostname
}
|
||||
131
internal/config/config.go
Normal file
131
internal/config/config.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
// Config holds all runtime settings for the agent, sourced from the
// environment and an optional .env file (see LoadConfig).
type Config struct {
	// Supabase Configuration
	SupabaseProjectURL string

	// Edge Function Endpoints (auto-generated from SupabaseProjectURL)
	DeviceAuthURL string
	AgentAuthURL  string

	// Agent Configuration
	TokenPath       string
	MetricsInterval int // seconds between metric submissions

	// Debug/Development
	Debug bool
}

// DefaultConfig supplies the baseline values that LoadConfig starts from
// before applying .env and environment overrides.
var DefaultConfig = Config{
	TokenPath:       "./token.json",
	MetricsInterval: 30,
	Debug:           false,
}
|
||||
|
||||
// LoadConfig loads configuration from environment variables and .env file
|
||||
func LoadConfig() (*Config, error) {
|
||||
config := DefaultConfig
|
||||
|
||||
// Try to load .env file from current directory or parent directories
|
||||
envFile := findEnvFile()
|
||||
if envFile != "" {
|
||||
if err := godotenv.Load(envFile); err != nil {
|
||||
fmt.Printf("Warning: Could not load .env file from %s: %v\n", envFile, err)
|
||||
} else {
|
||||
fmt.Printf("Loaded configuration from %s\n", envFile)
|
||||
}
|
||||
}
|
||||
|
||||
// Load from environment variables
|
||||
if url := os.Getenv("SUPABASE_PROJECT_URL"); url != "" {
|
||||
config.SupabaseProjectURL = url
|
||||
}
|
||||
|
||||
if tokenPath := os.Getenv("TOKEN_PATH"); tokenPath != "" {
|
||||
config.TokenPath = tokenPath
|
||||
}
|
||||
|
||||
if debug := os.Getenv("DEBUG"); debug == "true" || debug == "1" {
|
||||
config.Debug = true
|
||||
}
|
||||
|
||||
// Auto-generate edge function URLs from project URL
|
||||
if config.SupabaseProjectURL != "" {
|
||||
config.DeviceAuthURL = fmt.Sprintf("%s/functions/v1/device-auth", config.SupabaseProjectURL)
|
||||
config.AgentAuthURL = fmt.Sprintf("%s/functions/v1/agent-auth-api", config.SupabaseProjectURL)
|
||||
}
|
||||
|
||||
// Validate required configuration
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("configuration validation failed: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// Validate checks if all required configuration is present
|
||||
func (c *Config) Validate() error {
|
||||
var missing []string
|
||||
|
||||
if c.SupabaseProjectURL == "" {
|
||||
missing = append(missing, "SUPABASE_PROJECT_URL")
|
||||
}
|
||||
|
||||
if c.DeviceAuthURL == "" {
|
||||
missing = append(missing, "DEVICE_AUTH_URL (or SUPABASE_PROJECT_URL)")
|
||||
}
|
||||
|
||||
if c.AgentAuthURL == "" {
|
||||
missing = append(missing, "AGENT_AUTH_URL (or SUPABASE_PROJECT_URL)")
|
||||
}
|
||||
|
||||
if len(missing) > 0 {
|
||||
return fmt.Errorf("missing required environment variables: %s", strings.Join(missing, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findEnvFile walks from the current working directory up toward the
// filesystem root and returns the path of the first ".env" file found,
// or "" when none exists (or the working directory cannot be determined).
func findEnvFile() string {
	dir, err := os.Getwd()
	if err != nil {
		return ""
	}

	for {
		candidate := filepath.Join(dir, ".env")
		if _, statErr := os.Stat(candidate); statErr == nil {
			return candidate
		}

		parent := filepath.Dir(dir)
		if parent == dir {
			// Reached the filesystem root without finding a .env file.
			return ""
		}
		dir = parent
	}
}
|
||||
|
||||
// PrintConfig prints the current configuration (masking sensitive values)
|
||||
func (c *Config) PrintConfig() {
|
||||
if !c.Debug {
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Configuration:")
|
||||
fmt.Printf(" Supabase Project URL: %s\n", c.SupabaseProjectURL)
|
||||
fmt.Printf(" Metrics Interval: %d seconds\n", c.MetricsInterval)
|
||||
fmt.Printf(" Debug: %v\n", c.Debug)
|
||||
}
|
||||
90
internal/logging/logger.go
Normal file
90
internal/logging/logger.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"log/syslog"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Logger writes log lines both to the local syslog daemon (when reachable)
// and to the standard log output. Debug messages are suppressed unless
// debugMode is enabled.
type Logger struct {
	syslogWriter *syslog.Writer // nil when syslog is unavailable
	debugMode    bool           // set from the DEBUG environment variable
}

// defaultLogger backs the package-level Info/Debug/Warning/Error helpers.
var defaultLogger *Logger

func init() {
	// Initialize the package-wide logger once at import time.
	defaultLogger = NewLogger()
}
|
||||
|
||||
func NewLogger() *Logger {
|
||||
l := &Logger{
|
||||
debugMode: os.Getenv("DEBUG") == "true",
|
||||
}
|
||||
|
||||
// Try to connect to syslog
|
||||
if writer, err := syslog.New(syslog.LOG_INFO|syslog.LOG_DAEMON, "nannyagentv2"); err == nil {
|
||||
l.syslogWriter = writer
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if l.syslogWriter != nil {
|
||||
l.syslogWriter.Info(msg)
|
||||
}
|
||||
log.Printf("[INFO] %s", msg)
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
if !l.debugMode {
|
||||
return
|
||||
}
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if l.syslogWriter != nil {
|
||||
l.syslogWriter.Debug(msg)
|
||||
}
|
||||
log.Printf("[DEBUG] %s", msg)
|
||||
}
|
||||
|
||||
func (l *Logger) Warning(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if l.syslogWriter != nil {
|
||||
l.syslogWriter.Warning(msg)
|
||||
}
|
||||
log.Printf("[WARNING] %s", msg)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if l.syslogWriter != nil {
|
||||
l.syslogWriter.Err(msg)
|
||||
}
|
||||
log.Printf("[ERROR] %s", msg)
|
||||
}
|
||||
|
||||
func (l *Logger) Close() {
|
||||
if l.syslogWriter != nil {
|
||||
l.syslogWriter.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Global logging functions: convenience wrappers that forward to the
// package-level defaultLogger created in init().

// Info logs an informational message via the default logger.
func Info(format string, args ...interface{}) {
	defaultLogger.Info(format, args...)
}

// Debug logs a debug message via the default logger (no-op unless DEBUG=true).
func Debug(format string, args ...interface{}) {
	defaultLogger.Debug(format, args...)
}

// Warning logs a warning message via the default logger.
func Warning(format string, args ...interface{}) {
	defaultLogger.Warning(format, args...)
}

// Error logs an error message via the default logger.
func Error(format string, args ...interface{}) {
	defaultLogger.Error(format, args...)
}
|
||||
318
internal/metrics/collector.go
Normal file
318
internal/metrics/collector.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"github.com/shirou/gopsutil/v3/disk"
|
||||
"github.com/shirou/gopsutil/v3/host"
|
||||
"github.com/shirou/gopsutil/v3/load"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
psnet "github.com/shirou/gopsutil/v3/net"
|
||||
|
||||
"nannyagentv2/internal/types"
|
||||
)
|
||||
|
||||
// Collector handles system metrics collection and submission to the
// agent-auth-api backend.
type Collector struct {
	agentVersion string // reported with every metrics payload
}

// NewCollector creates a new metrics collector that tags its payloads
// with the given agent version string.
func NewCollector(agentVersion string) *Collector {
	return &Collector{
		agentVersion: agentVersion,
	}
}
|
||||
|
||||
// GatherSystemMetrics collects comprehensive system metrics via gopsutil.
// Collection is best effort: each probe's error is ignored and the
// corresponding fields are simply left at their zero values, so the
// returned error is currently always nil. Note that cpu.Percent blocks
// for one second to sample CPU usage.
func (c *Collector) GatherSystemMetrics() (*types.SystemMetrics, error) {
	metrics := &types.SystemMetrics{
		Timestamp: time.Now(),
	}

	// System Information
	if hostInfo, err := host.Info(); err == nil {
		metrics.Hostname = hostInfo.Hostname
		metrics.Platform = hostInfo.Platform
		metrics.PlatformFamily = hostInfo.PlatformFamily
		metrics.PlatformVersion = hostInfo.PlatformVersion
		metrics.KernelVersion = hostInfo.KernelVersion
		metrics.KernelArch = hostInfo.KernelArch
	}

	// CPU Metrics (the 1s interval makes this call blocking)
	if percentages, err := cpu.Percent(time.Second, false); err == nil && len(percentages) > 0 {
		metrics.CPUUsage = math.Round(percentages[0]*100) / 100
	}

	if cpuInfo, err := cpu.Info(); err == nil && len(cpuInfo) > 0 {
		metrics.CPUCores = len(cpuInfo)
		metrics.CPUModel = cpuInfo[0].ModelName
	}

	// Memory Metrics
	if memInfo, err := mem.VirtualMemory(); err == nil {
		metrics.MemoryUsage = math.Round(float64(memInfo.Used)/(1024*1024)*100) / 100 // MB
		metrics.MemoryTotal = memInfo.Total
		metrics.MemoryUsed = memInfo.Used
		metrics.MemoryFree = memInfo.Free
		metrics.MemoryAvailable = memInfo.Available
	}

	if swapInfo, err := mem.SwapMemory(); err == nil {
		metrics.SwapTotal = swapInfo.Total
		metrics.SwapUsed = swapInfo.Used
		metrics.SwapFree = swapInfo.Free
	}

	// Disk Metrics (root filesystem only)
	if diskInfo, err := disk.Usage("/"); err == nil {
		metrics.DiskUsage = math.Round(diskInfo.UsedPercent*100) / 100
		metrics.DiskTotal = diskInfo.Total
		metrics.DiskUsed = diskInfo.Used
		metrics.DiskFree = diskInfo.Free
	}

	// Load Averages
	if loadAvg, err := load.Avg(); err == nil {
		metrics.LoadAvg1 = math.Round(loadAvg.Load1*100) / 100
		metrics.LoadAvg5 = math.Round(loadAvg.Load5*100) / 100
		metrics.LoadAvg15 = math.Round(loadAvg.Load15*100) / 100
	}

	// Process Count (simplified - using a constant for now)
	// Note: gopsutil doesn't have host.Processes(), would need process.Processes()
	metrics.ProcessCount = 0 // Placeholder

	// Network Metrics
	netIn, netOut := c.getNetworkStats()
	metrics.NetworkInKbps = netIn
	metrics.NetworkOutKbps = netOut

	if netIOCounters, err := psnet.IOCounters(false); err == nil && len(netIOCounters) > 0 {
		netIO := netIOCounters[0]
		metrics.NetworkInBytes = netIO.BytesRecv
		metrics.NetworkOutBytes = netIO.BytesSent
	}

	// IP Address and Location
	metrics.IPAddress = c.getIPAddress()
	metrics.Location = c.getLocation() // Placeholder

	// Filesystem Information
	metrics.FilesystemInfo = c.getFilesystemInfo()

	// Block Devices
	metrics.BlockDevices = c.getBlockDevices()

	return metrics, nil
}
|
||||
|
||||
// getNetworkStats returns network input/output figures labelled as Kbps.
// NOTE(review): these are the cumulative byte counters since boot converted
// to kilobits, not a true per-second rate — a real rate would require
// sampling the counters twice and dividing by the interval. Confirm that
// consumers of NetworkInKbps/NetworkOutKbps expect cumulative values.
func (c *Collector) getNetworkStats() (float64, float64) {
	netIOCounters, err := psnet.IOCounters(false)
	if err != nil || len(netIOCounters) == 0 {
		return 0.0, 0.0
	}

	// Use the first interface for aggregate stats
	// (IOCounters(false) returns a single pre-aggregated entry).
	netIO := netIOCounters[0]

	// Convert bytes to kilobits per second (simplified - cumulative bytes to kilobits)
	netInKbps := float64(netIO.BytesRecv) * 8 / 1024
	netOutKbps := float64(netIO.BytesSent) * 8 / 1024

	return netInKbps, netOutKbps
}
|
||||
|
||||
// getIPAddress returns the primary IP address of the system
|
||||
func (c *Collector) getIPAddress() string {
|
||||
interfaces, err := psnet.Interfaces()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
for _, iface := range interfaces {
|
||||
if len(iface.Addrs) > 0 && !strings.Contains(iface.Addrs[0].Addr, "127.0.0.1") {
|
||||
return strings.Split(iface.Addrs[0].Addr, "/")[0] // Remove CIDR if present
|
||||
}
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// getLocation returns basic location information (placeholder)
|
||||
func (c *Collector) getLocation() string {
|
||||
return "unknown" // Would integrate with GeoIP service
|
||||
}
|
||||
|
||||
// getFilesystemInfo returns information about mounted filesystems
|
||||
func (c *Collector) getFilesystemInfo() []types.FilesystemInfo {
|
||||
partitions, err := disk.Partitions(false)
|
||||
if err != nil {
|
||||
return []types.FilesystemInfo{}
|
||||
}
|
||||
|
||||
var filesystems []types.FilesystemInfo
|
||||
for _, partition := range partitions {
|
||||
usage, err := disk.Usage(partition.Mountpoint)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
fs := types.FilesystemInfo{
|
||||
Mountpoint: partition.Mountpoint,
|
||||
Fstype: partition.Fstype,
|
||||
Total: usage.Total,
|
||||
Used: usage.Used,
|
||||
Free: usage.Free,
|
||||
UsagePercent: math.Round(usage.UsedPercent*100) / 100,
|
||||
}
|
||||
filesystems = append(filesystems, fs)
|
||||
}
|
||||
|
||||
return filesystems
|
||||
}
|
||||
|
||||
// getBlockDevices returns information about block devices
|
||||
func (c *Collector) getBlockDevices() []types.BlockDevice {
|
||||
partitions, err := disk.Partitions(true)
|
||||
if err != nil {
|
||||
return []types.BlockDevice{}
|
||||
}
|
||||
|
||||
var devices []types.BlockDevice
|
||||
deviceMap := make(map[string]bool)
|
||||
|
||||
for _, partition := range partitions {
|
||||
// Only include actual block devices
|
||||
if strings.HasPrefix(partition.Device, "/dev/") {
|
||||
deviceName := partition.Device
|
||||
if !deviceMap[deviceName] {
|
||||
deviceMap[deviceName] = true
|
||||
|
||||
device := types.BlockDevice{
|
||||
Name: deviceName,
|
||||
Model: "unknown",
|
||||
Size: 0,
|
||||
SerialNumber: "unknown",
|
||||
}
|
||||
devices = append(devices, device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return devices
|
||||
}
|
||||
|
||||
// SendMetrics sends system metrics to the agent-auth-api endpoint
|
||||
func (c *Collector) SendMetrics(agentAuthURL, accessToken, agentID string, metrics *types.SystemMetrics) error {
|
||||
// Create flattened metrics request for agent-auth-api
|
||||
metricsReq := c.CreateMetricsRequest(agentID, metrics)
|
||||
|
||||
return c.sendMetricsRequest(agentAuthURL, accessToken, metricsReq)
|
||||
}
|
||||
|
||||
// CreateMetricsRequest converts SystemMetrics into the flattened
// MetricsRequest format expected by agent-auth-api, attaching the agent ID,
// the collector's agent version, and a derived device fingerprint.
func (c *Collector) CreateMetricsRequest(agentID string, systemMetrics *types.SystemMetrics) *types.MetricsRequest {
	return &types.MetricsRequest{
		AgentID:           agentID,
		CPUUsage:          systemMetrics.CPUUsage,
		MemoryUsage:       systemMetrics.MemoryUsage,
		DiskUsage:         systemMetrics.DiskUsage,
		NetworkInKbps:     systemMetrics.NetworkInKbps,
		NetworkOutKbps:    systemMetrics.NetworkOutKbps,
		IPAddress:         systemMetrics.IPAddress,
		Location:          systemMetrics.Location,
		AgentVersion:      c.agentVersion,
		KernelVersion:     systemMetrics.KernelVersion,
		DeviceFingerprint: c.generateDeviceFingerprint(systemMetrics),
		// Structured sections below map onto JSON columns server-side.
		LoadAverages: map[string]float64{
			"load1":  systemMetrics.LoadAvg1,
			"load5":  systemMetrics.LoadAvg5,
			"load15": systemMetrics.LoadAvg15,
		},
		OSInfo: map[string]string{
			"cpu_cores":        fmt.Sprintf("%d", systemMetrics.CPUCores),
			"memory":           fmt.Sprintf("%.1fGi", float64(systemMetrics.MemoryTotal)/(1024*1024*1024)),
			"uptime":           "unknown", // Will be calculated by the server or client
			"platform":         systemMetrics.Platform,
			"platform_family":  systemMetrics.PlatformFamily,
			"platform_version": systemMetrics.PlatformVersion,
			"kernel_version":   systemMetrics.KernelVersion,
			"kernel_arch":      systemMetrics.KernelArch,
		},
		FilesystemInfo: systemMetrics.FilesystemInfo,
		BlockDevices:   systemMetrics.BlockDevices,
		NetworkStats: map[string]uint64{
			"bytes_sent":  systemMetrics.NetworkOutBytes,
			"bytes_recv":  systemMetrics.NetworkInBytes,
			"total_bytes": systemMetrics.NetworkInBytes + systemMetrics.NetworkOutBytes,
		},
	}
}
|
||||
|
||||
// sendMetricsRequest POSTs the metrics payload (wrapped with an RFC 3339
// UTC timestamp) to the agent-auth-api's /metrics endpoint using the given
// bearer token. It returns a distinguishable "unauthorized" error on HTTP
// 401 so callers can trigger a token refresh, and a descriptive error for
// any other non-200 response.
func (c *Collector) sendMetricsRequest(agentAuthURL, accessToken string, metricsReq *types.MetricsRequest) error {
	// Wrap metrics in the expected payload structure
	payload := map[string]interface{}{
		"metrics":   metricsReq,
		"timestamp": time.Now().UTC().Format(time.RFC3339),
	}

	jsonData, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("failed to marshal metrics: %w", err)
	}

	// Send to /metrics endpoint
	metricsURL := fmt.Sprintf("%s/metrics", agentAuthURL)
	req, err := http.NewRequest("POST", metricsURL, bytes.NewBuffer(jsonData))
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))

	// 30s timeout bounds the whole request/response round trip.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to send metrics: %w", err)
	}
	defer resp.Body.Close()

	// Read response body for inclusion in error messages.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read response: %w", err)
	}

	// Check response status: 401 is surfaced specially so callers can refresh.
	if resp.StatusCode == http.StatusUnauthorized {
		return fmt.Errorf("unauthorized")
	}

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("metrics request failed with status %d: %s", resp.StatusCode, string(body))
	}

	return nil
}
|
||||
|
||||
// generateDeviceFingerprint creates a unique device identifier
|
||||
func (c *Collector) generateDeviceFingerprint(metrics *types.SystemMetrics) string {
|
||||
fingerprint := fmt.Sprintf("%s-%s-%s", metrics.Hostname, metrics.Platform, metrics.KernelVersion)
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(fingerprint))
|
||||
return fmt.Sprintf("%x", hasher.Sum(nil))[:16]
|
||||
}
|
||||
337
internal/types/types.go
Normal file
337
internal/types/types.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package types
|
||||
|
||||
import "time"
|
||||
|
||||
// SystemMetrics represents comprehensive system performance metrics
|
||||
// SystemMetrics represents comprehensive system performance metrics
// gathered in a single collection pass.
type SystemMetrics struct {
	// System Information
	Hostname        string `json:"hostname"`
	Platform        string `json:"platform"`
	PlatformFamily  string `json:"platform_family"`
	PlatformVersion string `json:"platform_version"`
	KernelVersion   string `json:"kernel_version"`
	KernelArch      string `json:"kernel_arch"`

	// CPU Metrics
	CPUUsage float64 `json:"cpu_usage"` // percent, rounded to 2 decimals
	CPUCores int     `json:"cpu_cores"`
	CPUModel string  `json:"cpu_model"`

	// Memory Metrics (byte counts except MemoryUsage, which is in MB)
	MemoryUsage     float64 `json:"memory_usage"`
	MemoryTotal     uint64  `json:"memory_total"`
	MemoryUsed      uint64  `json:"memory_used"`
	MemoryFree      uint64  `json:"memory_free"`
	MemoryAvailable uint64  `json:"memory_available"`
	SwapTotal       uint64  `json:"swap_total"`
	SwapUsed        uint64  `json:"swap_used"`
	SwapFree        uint64  `json:"swap_free"`

	// Disk Metrics (root filesystem)
	DiskUsage float64 `json:"disk_usage"` // percent
	DiskTotal uint64  `json:"disk_total"`
	DiskUsed  uint64  `json:"disk_used"`
	DiskFree  uint64  `json:"disk_free"`

	// Network Metrics
	NetworkInKbps   float64 `json:"network_in_kbps"`
	NetworkOutKbps  float64 `json:"network_out_kbps"`
	NetworkInBytes  uint64  `json:"network_in_bytes"`
	NetworkOutBytes uint64  `json:"network_out_bytes"`

	// System Load
	LoadAvg1  float64 `json:"load_avg_1"`
	LoadAvg5  float64 `json:"load_avg_5"`
	LoadAvg15 float64 `json:"load_avg_15"`

	// Process Information
	ProcessCount int `json:"process_count"`

	// Network Information
	IPAddress string `json:"ip_address"`
	Location  string `json:"location"`

	// Filesystem Information
	FilesystemInfo []FilesystemInfo `json:"filesystem_info"`
	BlockDevices   []BlockDevice    `json:"block_devices"`

	// Timestamp marks when collection started.
	Timestamp time.Time `json:"timestamp"`
}

// FilesystemInfo represents individual filesystem statistics.
type FilesystemInfo struct {
	Mountpoint   string  `json:"mountpoint"`
	Fstype       string  `json:"fstype"`
	Total        uint64  `json:"total"`
	Used         uint64  `json:"used"`
	Free         uint64  `json:"free"`
	UsagePercent float64 `json:"usage_percent"`
}

// BlockDevice represents block device information.
type BlockDevice struct {
	Name         string `json:"name"`
	Size         uint64 `json:"size"`
	Model        string `json:"model"`
	SerialNumber string `json:"serial_number"`
}

// NetworkStats represents detailed per-interface network statistics.
type NetworkStats struct {
	InterfaceName string `json:"interface_name"`
	BytesSent     uint64 `json:"bytes_sent"`
	BytesRecv     uint64 `json:"bytes_recv"`
	PacketsSent   uint64 `json:"packets_sent"`
	PacketsRecv   uint64 `json:"packets_recv"`
	ErrorsIn      uint64 `json:"errors_in"`
	ErrorsOut     uint64 `json:"errors_out"`
	DropsIn       uint64 `json:"drops_in"`
	DropsOut      uint64 `json:"drops_out"`
}
|
||||
|
||||
// AuthToken represents the authentication token structure
|
||||
// AuthToken is the locally persisted authentication state: the bearer
// access token, its refresh token, the locally computed expiry instant,
// and the agent's identity.
type AuthToken struct {
	AccessToken  string    `json:"access_token"`
	RefreshToken string    `json:"refresh_token"`
	ExpiresAt    time.Time `json:"expires_at"`
	TokenType    string    `json:"token_type"`
	AgentID      string    `json:"agent_id"`
}

// DeviceAuthRequest represents the device authorization request.
type DeviceAuthRequest struct {
	ClientID string `json:"client_id"`
	Scope    string `json:"scope,omitempty"`
}

// DeviceAuthResponse represents the device authorization response:
// the code pair to show the user plus polling parameters.
type DeviceAuthResponse struct {
	DeviceCode      string `json:"device_code"`
	UserCode        string `json:"user_code"`
	VerificationURI string `json:"verification_uri"`
	ExpiresIn       int    `json:"expires_in"` // seconds until the codes expire
	Interval        int    `json:"interval"`   // seconds between polls
}

// TokenRequest represents the token request for the device flow;
// exactly one of DeviceCode or RefreshToken is set per grant type.
type TokenRequest struct {
	GrantType    string `json:"grant_type"`
	DeviceCode   string `json:"device_code,omitempty"`
	RefreshToken string `json:"refresh_token,omitempty"`
	ClientID     string `json:"client_id,omitempty"`
}

// TokenResponse represents the token response; Error/ErrorDescription are
// populated instead of the token fields when the grant fails.
type TokenResponse struct {
	AccessToken      string `json:"access_token"`
	RefreshToken     string `json:"refresh_token"`
	TokenType        string `json:"token_type"`
	ExpiresIn        int    `json:"expires_in"` // seconds
	AgentID          string `json:"agent_id,omitempty"`
	Error            string `json:"error,omitempty"`
	ErrorDescription string `json:"error_description,omitempty"`
}
|
||||
|
||||
// HeartbeatRequest represents the agent heartbeat request
|
||||
// HeartbeatRequest represents the agent heartbeat request.
type HeartbeatRequest struct {
	AgentID string        `json:"agent_id"`
	Status  string        `json:"status"`
	Metrics SystemMetrics `json:"metrics"`
}

// MetricsRequest represents the flattened metrics payload expected by
// agent-auth-api (see metrics.Collector.CreateMetricsRequest).
type MetricsRequest struct {
	// Agent identification
	AgentID string `json:"agent_id"`

	// Basic metrics
	CPUUsage    float64 `json:"cpu_usage"`
	MemoryUsage float64 `json:"memory_usage"`
	DiskUsage   float64 `json:"disk_usage"`

	// Network metrics
	NetworkInKbps  float64 `json:"network_in_kbps"`
	NetworkOutKbps float64 `json:"network_out_kbps"`

	// System information
	IPAddress         string `json:"ip_address"`
	Location          string `json:"location"`
	AgentVersion      string `json:"agent_version"`
	KernelVersion     string `json:"kernel_version"`
	DeviceFingerprint string `json:"device_fingerprint"`

	// Structured data (JSON fields in database)
	LoadAverages   map[string]float64 `json:"load_averages"`
	OSInfo         map[string]string  `json:"os_info"`
	FilesystemInfo []FilesystemInfo   `json:"filesystem_info"`
	BlockDevices   []BlockDevice      `json:"block_devices"`
	NetworkStats   map[string]uint64  `json:"network_stats"`
}
|
||||
|
||||
// eBPF related types
|
||||
// eBPF related types

// EBPFEvent is a single event captured by an eBPF probe.
type EBPFEvent struct {
	Timestamp   int64                  `json:"timestamp"`
	EventType   string                 `json:"event_type"`
	ProcessID   int                    `json:"process_id"`
	ProcessName string                 `json:"process_name"`
	UserID      int                    `json:"user_id"`
	Data        map[string]interface{} `json:"data"`
}

// EBPFTrace groups the events of one tracing session with its summary.
type EBPFTrace struct {
	TraceID     string      `json:"trace_id"`
	StartTime   time.Time   `json:"start_time"`
	EndTime     time.Time   `json:"end_time"`
	Capability  string      `json:"capability"`
	Events      []EBPFEvent `json:"events"`
	Summary     string      `json:"summary"`
	EventCount  int         `json:"event_count"`
	ProcessList []string    `json:"process_list"`
}

// EBPFRequest describes a probe to attach and how long to run it.
type EBPFRequest struct {
	Name        string            `json:"name"`
	Type        string            `json:"type"`   // "tracepoint", "kprobe", "kretprobe"
	Target      string            `json:"target"` // tracepoint path or function name
	Duration    int               `json:"duration"` // seconds
	Filters     map[string]string `json:"filters,omitempty"`
	Description string            `json:"description"`
}

// NetworkEvent mirrors a raw network event record; Comm holds the
// fixed-width kernel comm buffer and is excluded from JSON in favor of
// the decoded CommStr.
type NetworkEvent struct {
	Timestamp uint64   `json:"timestamp"`
	PID       uint32   `json:"pid"`
	TID       uint32   `json:"tid"`
	UID       uint32   `json:"uid"`
	EventType string   `json:"event_type"`
	Comm      [16]byte `json:"-"`
	CommStr   string   `json:"comm"`
}
|
||||
|
||||
// Agent types

// DiagnosticResponse is the model's "diagnostic" turn: its reasoning plus the
// commands it wants the agent to execute.
type DiagnosticResponse struct {
	ResponseType string    `json:"response_type"` // "diagnostic" (dispatched on in handleInvestigation)
	Reasoning    string    `json:"reasoning"`
	Commands     []Command `json:"commands"`
}

// ResolutionResponse is the model's final "resolution" turn.
type ResolutionResponse struct {
	ResponseType   string `json:"response_type"` // "resolution"
	RootCause      string `json:"root_cause"`
	ResolutionPlan string `json:"resolution_plan"`
	Confidence     string `json:"confidence"`
}

// Command is a single shell command requested by the model.
type Command struct {
	ID          string `json:"id"` // correlates the command with its CommandResult
	Command     string `json:"command"`
	Description string `json:"description"`
}

// CommandResult captures the outcome of executing one Command.
type CommandResult struct {
	ID          string `json:"id"`
	Command     string `json:"command"`
	Description string `json:"description"`
	Output      string `json:"output"`
	ExitCode    int    `json:"exit_code"`
	Error       string `json:"error,omitempty"`
}

// EBPFEnhancedDiagnosticResponse extends DiagnosticResponse with eBPF
// programs to load alongside the shell commands.
type EBPFEnhancedDiagnosticResponse struct {
	ResponseType string        `json:"response_type"`
	Reasoning    string        `json:"reasoning"`
	Commands     []Command     `json:"commands"`
	EBPFPrograms []EBPFRequest `json:"ebpf_programs"`
	NextActions  []string      `json:"next_actions,omitempty"`
}
|
||||
|
||||
// TensorZeroRequest is the OpenAI-compatible request body sent to TensorZero.
type TensorZeroRequest struct {
	Model    string                   `json:"model"`
	Messages []map[string]interface{} `json:"messages"`
	// EpisodeID threads multiple calls into one TensorZero episode; the
	// "tensorzero::"-prefixed tag is TensorZero's extension field name.
	EpisodeID string `json:"tensorzero::episode_id,omitempty"`
}

// TensorZeroResponse is the subset of the TensorZero response the agent reads.
type TensorZeroResponse struct {
	Choices   []map[string]interface{} `json:"choices"`
	EpisodeID string                   `json:"episode_id"`
}
|
||||
|
||||
// WebSocket types

// WebSocketMessage is the generic envelope exchanged over the WebSocket link.
type WebSocketMessage struct {
	Type string      `json:"type"` // discriminator for Data's concrete shape
	Data interface{} `json:"data"`
}

// InvestigationTask is a backend-assigned investigation delivered to the agent.
type InvestigationTask struct {
	TaskID            string                 `json:"task_id"`
	InvestigationID   string                 `json:"investigation_id"`
	AgentID           string                 `json:"agent_id"`
	DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
	EpisodeID         string                 `json:"episode_id,omitempty"`
}

// TaskResult reports the outcome of an InvestigationTask back to the backend.
type TaskResult struct {
	TaskID         string                 `json:"task_id"`
	Success        bool                   `json:"success"`
	CommandResults map[string]interface{} `json:"command_results,omitempty"`
	Error          string                 `json:"error,omitempty"`
}

// HeartbeatData is a lightweight heartbeat payload.
type HeartbeatData struct {
	AgentID   string    `json:"agent_id"`
	Timestamp time.Time `json:"timestamp"`
	Version   string    `json:"version"`
}
|
||||
|
||||
// Investigation server types
//
// NOTE(review): investigation_server.go declares its own InvestigationRequest,
// InvestigationResponse, and PendingInvestigation with different shapes in the
// same package — confirm only one set of declarations survives, otherwise the
// package cannot compile.

// InvestigationRequest is an investigation request payload.
type InvestigationRequest struct {
	Issue       string `json:"issue"`
	AgentID     string `json:"agent_id"`
	EpisodeID   string `json:"episode_id,omitempty"`
	Timestamp   string `json:"timestamp,omitempty"`
	Priority    string `json:"priority,omitempty"`
	Description string `json:"description,omitempty"`
}

// InvestigationResponse is the agent's reply to an investigation request.
type InvestigationResponse struct {
	Status        string                 `json:"status"`
	Message       string                 `json:"message"`
	Results       map[string]interface{} `json:"results,omitempty"`
	AgentID       string                 `json:"agent_id"`
	Timestamp     string                 `json:"timestamp"`
	EpisodeID     string                 `json:"episode_id,omitempty"`
	Investigation *PendingInvestigation  `json:"investigation,omitempty"`
}

// PendingInvestigation mirrors a row of the pending-investigations store.
// Nullable columns are modeled as *string.
type PendingInvestigation struct {
	ID                string                 `json:"id"`
	Issue             string                 `json:"issue"`
	AgentID           string                 `json:"agent_id"`
	Status            string                 `json:"status"`
	DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
	CommandResults    map[string]interface{} `json:"command_results,omitempty"`
	EpisodeID         *string                `json:"episode_id,omitempty"`
	CreatedAt         string                 `json:"created_at"`
	StartedAt         *string                `json:"started_at,omitempty"`
	CompletedAt       *string                `json:"completed_at,omitempty"`
	ErrorMessage      *string                `json:"error_message,omitempty"`
}
|
||||
|
||||
// System types

// SystemInfo is a snapshot of host-level system information.
type SystemInfo struct {
	Hostname      string              `json:"hostname"`
	Platform      string              `json:"platform"`
	PlatformInfo  map[string]string   `json:"platform_info"`
	KernelVersion string              `json:"kernel_version"`
	Uptime        string              `json:"uptime"`
	LoadAverage   []float64           `json:"load_average"`
	CPUInfo       map[string]string   `json:"cpu_info"`
	MemoryInfo    map[string]string   `json:"memory_info"`
	DiskInfo      []map[string]string `json:"disk_info"`
}

// Executor types

// CommandExecutor executes commands with a per-command timeout
// (constructed with a 10s timeout by NewLinuxDiagnosticAgent).
type CommandExecutor struct {
	timeout time.Duration // maximum wall-clock time allowed per command
}
|
||||
527
investigation_server.go
Normal file
527
investigation_server.go
Normal file
@@ -0,0 +1,527 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nannyagentv2/internal/auth"
|
||||
"nannyagentv2/internal/metrics"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// InvestigationRequest represents a request from Supabase to start an investigation.
// NOTE(review): a type of the same name (different shape) is declared in the
// shared types section of this package — only one declaration can survive.
type InvestigationRequest struct {
	InvestigationID  string            `json:"investigation_id"`
	ApplicationGroup string            `json:"application_group"`
	Issue            string            `json:"issue"`
	Context          map[string]string `json:"context"`
	Priority         string            `json:"priority"`
	InitiatedBy      string            `json:"initiated_by"`
}
|
||||
|
||||
// InvestigationResponse represents the agent's response to an investigation.
// NOTE(review): a type of the same name (different shape) is declared in the
// shared types section of this package — only one declaration can survive.
type InvestigationResponse struct {
	AgentID         string          `json:"agent_id"`
	InvestigationID string          `json:"investigation_id"`
	Status          string          `json:"status"`
	Commands        []CommandResult `json:"commands,omitempty"`
	AIResponse      string          `json:"ai_response,omitempty"`
	EpisodeID       string          `json:"episode_id,omitempty"`
	Timestamp       time.Time       `json:"timestamp"`
	Error           string          `json:"error,omitempty"`
}
|
||||
|
||||
// InvestigationServer handles reverse investigation requests from Supabase.
// It serves HTTP endpoints (/health, /investigate, /status) and, when
// configured, polls Supabase for backend-initiated investigations.
type InvestigationServer struct {
	agent            *LinuxDiagnosticAgent // original agent for direct user interactions
	applicationAgent *LinuxDiagnosticAgent // separate agent for application-initiated investigations
	port             string                // HTTP listen port (AGENT_PORT, default "1234")
	agentID          string                // resolved agent identity
	metricsCollector *metrics.Collector
	authManager      *auth.AuthManager
	startTime        time.Time // used to report uptime in /status
	supabaseURL      string    // SUPABASE_PROJECT_URL; empty disables realtime polling
}
|
||||
|
||||
// NewInvestigationServer creates a new investigation server.
// The listen port comes from AGENT_PORT (default "1234"). The agent ID is
// resolved in order of preference: authenticated identity from authManager,
// the AGENT_ID environment variable, and finally a timestamp-derived fallback.
// A second diagnostic agent is created for application-initiated
// investigations, pointed at a dedicated TensorZero function.
func NewInvestigationServer(agent *LinuxDiagnosticAgent, authManager *auth.AuthManager) *InvestigationServer {
	port := os.Getenv("AGENT_PORT")
	if port == "" {
		port = "1234"
	}

	// Get agent ID from authentication system
	var agentID string
	if authManager != nil {
		if id, err := authManager.GetCurrentAgentID(); err == nil {
			agentID = id
		} else {
			fmt.Printf("❌ Failed to get agent ID from auth manager: %v\n", err)
		}
	}

	// Fallback to environment variable or generate one if auth fails
	if agentID == "" {
		agentID = os.Getenv("AGENT_ID")
		if agentID == "" {
			// Last resort: time-derived ID — not stable across restarts.
			agentID = fmt.Sprintf("agent-%d", time.Now().Unix())
		}
	}

	// Create metrics collector
	metricsCollector := metrics.NewCollector("v2.0.0")

	// Create a separate agent for application-initiated investigations
	applicationAgent := NewLinuxDiagnosticAgent()
	// Override the model to use the application-specific function
	applicationAgent.model = "tensorzero::function_name::diagnose_and_heal_application"

	return &InvestigationServer{
		agent:            agent,
		applicationAgent: applicationAgent,
		port:             port,
		agentID:          agentID,
		metricsCollector: metricsCollector,
		authManager:      authManager,
		startTime:        time.Now(),
		supabaseURL:      os.Getenv("SUPABASE_PROJECT_URL"),
	}
}
|
||||
|
||||
// DiagnoseIssueForApplication handles diagnostic requests initiated from the
// application/portal. It sets the given episode ID on the application agent
// so the TensorZero conversation continues in the same episode, then runs the
// standard diagnosis flow.
func (s *InvestigationServer) DiagnoseIssueForApplication(issue, episodeID string) error {
	// Set the episode ID on the application agent for continuity
	s.applicationAgent.episodeID = episodeID
	return s.applicationAgent.DiagnoseIssue(issue)
}
|
||||
|
||||
// Start starts the HTTP server and realtime polling for investigation
// requests. It blocks in ListenAndServe until the server stops, and returns
// its error.
func (s *InvestigationServer) Start() error {
	mux := http.NewServeMux()

	// Health check endpoint
	mux.HandleFunc("/health", s.handleHealth)

	// Investigation endpoint (POST payloads from the backend)
	mux.HandleFunc("/investigate", s.handleInvestigation)

	// Agent status endpoint
	mux.HandleFunc("/status", s.handleStatus)

	// Start realtime polling for backend-initiated investigations.
	// Requires both a Supabase URL and an auth manager for REST access.
	if s.supabaseURL != "" && s.authManager != nil {
		go s.startRealtimePolling()
		fmt.Printf("🔄 Realtime investigation polling enabled\n")
	} else {
		fmt.Printf("⚠️ Realtime investigation polling disabled (missing Supabase config or auth)\n")
	}

	server := &http.Server{
		Addr:         ":" + s.port,
		Handler:      mux,
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
	}

	fmt.Printf("🔍 Investigation server started on port %s (Agent ID: %s)\n", s.port, s.agentID)
	return server.ListenAndServe()
}
|
||||
|
||||
// handleHealth responds to health check requests
|
||||
func (s *InvestigationServer) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"status": "healthy",
|
||||
"agent_id": s.agentID,
|
||||
"timestamp": time.Now(),
|
||||
"version": "v2.0.0",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// handleStatus responds with agent status and capabilities
|
||||
func (s *InvestigationServer) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Collect current system metrics
|
||||
systemMetrics, err := s.metricsCollector.GatherSystemMetrics()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to collect metrics: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to metrics request format for consistent data structure
|
||||
metricsReq := s.metricsCollector.CreateMetricsRequest(s.agentID, systemMetrics)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"agent_id": s.agentID,
|
||||
"status": "ready",
|
||||
"capabilities": []string{"system_diagnostics", "ebpf_monitoring", "command_execution", "ai_analysis"},
|
||||
"system_info": map[string]interface{}{
|
||||
"os": fmt.Sprintf("%s %s", metricsReq.OSInfo["platform"], metricsReq.OSInfo["platform_version"]),
|
||||
"kernel": metricsReq.KernelVersion,
|
||||
"architecture": metricsReq.OSInfo["kernel_arch"],
|
||||
"cpu_cores": metricsReq.OSInfo["cpu_cores"],
|
||||
"memory": metricsReq.MemoryUsage,
|
||||
"private_ips": metricsReq.IPAddress,
|
||||
"load_average": fmt.Sprintf("%.2f, %.2f, %.2f",
|
||||
metricsReq.LoadAverages["load1"],
|
||||
metricsReq.LoadAverages["load5"],
|
||||
metricsReq.LoadAverages["load15"]),
|
||||
"disk_usage": fmt.Sprintf("Root: %.0fG/%.0fG (%.0f%% used)",
|
||||
float64(metricsReq.FilesystemInfo[0].Used)/1024/1024/1024,
|
||||
float64(metricsReq.FilesystemInfo[0].Total)/1024/1024/1024,
|
||||
metricsReq.DiskUsage),
|
||||
},
|
||||
"uptime": time.Since(s.startTime),
|
||||
"last_contact": time.Now(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// sendCommandResultsToTensorZero sends command results back to TensorZero and continues conversation
|
||||
func (s *InvestigationServer) sendCommandResultsToTensorZero(diagnosticResp DiagnosticResponse, commandResults []CommandResult) (interface{}, error) {
|
||||
// Build conversation history like in agent.go
|
||||
messages := []openai.ChatCompletionMessage{
|
||||
// Add the original diagnostic response as assistant message
|
||||
{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: fmt.Sprintf(`{"response_type":"diagnostic","reasoning":"%s","commands":%s}`,
|
||||
diagnosticResp.Reasoning,
|
||||
mustMarshalJSON(diagnosticResp.Commands)),
|
||||
},
|
||||
}
|
||||
|
||||
// Add command results as user message (same as agent.go does)
|
||||
resultsJSON, err := json.MarshalIndent(commandResults, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal command results: %w", err)
|
||||
}
|
||||
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: string(resultsJSON),
|
||||
})
|
||||
|
||||
// Send to TensorZero via application agent's sendRequest method
|
||||
fmt.Printf("🔄 Sending command results to TensorZero for analysis...\n")
|
||||
response, err := s.applicationAgent.sendRequest(messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request to TensorZero: %w", err)
|
||||
}
|
||||
|
||||
if len(response.Choices) == 0 {
|
||||
return nil, fmt.Errorf("no choices in TensorZero response")
|
||||
}
|
||||
|
||||
content := response.Choices[0].Message.Content
|
||||
fmt.Printf("🤖 TensorZero continued analysis:\n%s\n", content)
|
||||
|
||||
// Try to parse the response to determine if it's diagnostic or resolution
|
||||
var diagnosticNextResp DiagnosticResponse
|
||||
var resolutionResp ResolutionResponse
|
||||
|
||||
// Check if it's another diagnostic response
|
||||
if err := json.Unmarshal([]byte(content), &diagnosticNextResp); err == nil && diagnosticNextResp.ResponseType == "diagnostic" {
|
||||
fmt.Printf("🔄 TensorZero requests %d more commands\n", len(diagnosticNextResp.Commands))
|
||||
return map[string]interface{}{
|
||||
"type": "diagnostic",
|
||||
"response": diagnosticNextResp,
|
||||
"raw": content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Check if it's a resolution response
|
||||
if err := json.Unmarshal([]byte(content), &resolutionResp); err == nil && resolutionResp.ResponseType == "resolution" {
|
||||
|
||||
return map[string]interface{}{
|
||||
"type": "resolution",
|
||||
"response": resolutionResp,
|
||||
"raw": content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Return raw response if we can't parse it
|
||||
return map[string]interface{}{
|
||||
"type": "unknown",
|
||||
"raw": content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// mustMarshalJSON serializes v to JSON, swallowing any marshal error
// (on error json.Marshal returns nil data, so the result is "").
func mustMarshalJSON(v interface{}) string {
	out, _ := json.Marshal(v)
	return string(out)
}
|
||||
|
||||
// processInvestigation handles the actual investigation using TensorZero
|
||||
// This endpoint receives either:
|
||||
// 1. DiagnosticResponse - Commands and eBPF programs to execute
|
||||
// 2. ResolutionResponse - Final resolution (no execution needed)
|
||||
func (s *InvestigationServer) handleInvestigation(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed - only POST accepted", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the request body to determine what type of response this is
|
||||
var requestBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check the response_type field to determine how to handle this
|
||||
responseType, ok := requestBody["response_type"].(string)
|
||||
if !ok {
|
||||
http.Error(w, "Missing or invalid response_type field", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("📋 Received investigation payload with response_type: %s\n", responseType)
|
||||
|
||||
switch responseType {
|
||||
case "diagnostic":
|
||||
// This is a DiagnosticResponse with commands to execute
|
||||
response := s.handleDiagnosticExecution(requestBody)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
|
||||
case "resolution":
|
||||
// This is a ResolutionResponse - final result, just acknowledge
|
||||
fmt.Printf("📋 Received final resolution from backend\n")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Resolution received and acknowledged",
|
||||
"agent_id": s.agentID,
|
||||
})
|
||||
|
||||
default:
|
||||
http.Error(w, fmt.Sprintf("Unknown response_type: %s", responseType), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleDiagnosticExecution executes commands from a DiagnosticResponse.
// It re-parses the generic request map into a DiagnosticResponse, runs every
// command through the agent's executor, forwards the collected results to
// TensorZero for continued analysis, and returns a JSON-encodable result map
// whose "success" key reflects the overall outcome.
func (s *InvestigationServer) handleDiagnosticExecution(requestBody map[string]interface{}) map[string]interface{} {
	// Parse as DiagnosticResponse
	var diagnosticResp DiagnosticResponse

	// Convert the map back to JSON and then parse it properly
	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		return map[string]interface{}{
			"success":  false,
			"error":    fmt.Sprintf("Failed to re-marshal request: %v", err),
			"agent_id": s.agentID,
		}
	}

	if err := json.Unmarshal(jsonData, &diagnosticResp); err != nil {
		return map[string]interface{}{
			"success":  false,
			"error":    fmt.Sprintf("Failed to parse DiagnosticResponse: %v", err),
			"agent_id": s.agentID,
		}
	}

	fmt.Printf("📋 Executing %d commands from backend\n", len(diagnosticResp.Commands))

	// Execute all commands sequentially.
	commandResults := make([]CommandResult, 0, len(diagnosticResp.Commands))

	for _, cmd := range diagnosticResp.Commands {
		fmt.Printf("⚙️ Executing command '%s': %s\n", cmd.ID, cmd.Command)

		// Use the agent's executor to run the command
		result := s.agent.executor.Execute(cmd)
		commandResults = append(commandResults, result)

		// Per-command errors are logged but do not stop the remaining commands.
		if result.Error != "" {
			fmt.Printf("⚠️ Command '%s' had error: %s\n", cmd.ID, result.Error)
		}
	}

	// Send command results back to TensorZero for continued analysis
	fmt.Printf("🔄 Sending %d command results back to TensorZero for continued analysis\n", len(commandResults))

	nextResponse, err := s.sendCommandResultsToTensorZero(diagnosticResp, commandResults)
	if err != nil {
		return map[string]interface{}{
			"success":         false,
			"error":           fmt.Sprintf("Failed to continue TensorZero conversation: %v", err),
			"agent_id":        s.agentID,
			"command_results": commandResults, // Still return the results
		}
	}

	// Return both the command results and the next response from TensorZero
	return map[string]interface{}{
		"success":           true,
		"agent_id":          s.agentID,
		"command_results":   commandResults,
		"commands_executed": len(commandResults),
		"next_response":     nextResponse,
		"timestamp":         time.Now().Format(time.RFC3339),
	}
}
|
||||
|
||||
// PendingInvestigation represents a pending investigation row fetched from the
// Supabase pending_investigations table.
// NOTE(review): a type of the same name (different shape) is declared in the
// shared types section of this package — only one declaration can survive,
// otherwise the package will not compile.
type PendingInvestigation struct {
	ID                string                 `json:"id"`
	InvestigationID   string                 `json:"investigation_id"`
	AgentID           string                 `json:"agent_id"`
	DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
	EpisodeID         *string                `json:"episode_id"` // nullable column
	Status            string                 `json:"status"`
	CreatedAt         time.Time              `json:"created_at"`
}
|
||||
|
||||
// startRealtimePolling begins polling for pending investigations
|
||||
func (s *InvestigationServer) startRealtimePolling() {
|
||||
fmt.Printf("🔄 Starting realtime investigation polling for agent %s\n", s.agentID)
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second) // Poll every 5 seconds
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
s.checkForPendingInvestigations()
|
||||
}
|
||||
}
|
||||
|
||||
// checkForPendingInvestigations queries the Supabase REST API for
// investigations assigned to this agent that are still "pending" and
// dispatches each to its own goroutine. Every failure returns silently: this
// runs on a tight polling loop and transient errors are simply retried on the
// next tick.
// NOTE(review): a row could be re-fetched (and re-dispatched) on the next
// 5-second tick before handlePendingInvestigation marks it "executing" —
// confirm whether duplicate processing is acceptable here.
func (s *InvestigationServer) checkForPendingInvestigations() {
	// PostgREST-style filters: this agent, pending status, newest first.
	url := fmt.Sprintf("%s/rest/v1/pending_investigations?agent_id=eq.%s&status=eq.pending&order=created_at.desc",
		s.supabaseURL, s.agentID)

	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return // Silent fail for polling
	}

	// Get token from auth manager
	authToken, err := s.authManager.LoadToken()
	if err != nil {
		return // Silent fail for polling
	}

	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken.AccessToken))
	req.Header.Set("Accept", "application/json")

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return // Silent fail for polling
	}
	defer resp.Body.Close()

	if resp.StatusCode != 200 {
		return // Silent fail for polling
	}

	var investigations []PendingInvestigation
	err = json.NewDecoder(resp.Body).Decode(&investigations)
	if err != nil {
		return // Silent fail for polling
	}

	// Dispatch each investigation concurrently.
	for _, investigation := range investigations {
		fmt.Printf("🔍 Found pending investigation: %s\n", investigation.ID)
		go s.handlePendingInvestigation(investigation)
	}
}
|
||||
|
||||
// handlePendingInvestigation processes a single pending investigation
|
||||
func (s *InvestigationServer) handlePendingInvestigation(investigation PendingInvestigation) {
|
||||
fmt.Printf("🚀 Processing realtime investigation %s\n", investigation.InvestigationID)
|
||||
|
||||
// Mark as executing
|
||||
err := s.updateInvestigationStatus(investigation.ID, "executing", nil, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ Failed to mark investigation as executing: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute diagnostic commands using existing handleDiagnosticExecution method
|
||||
results := s.handleDiagnosticExecution(investigation.DiagnosticPayload)
|
||||
|
||||
// Mark as completed with results
|
||||
err = s.updateInvestigationStatus(investigation.ID, "completed", results, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ Failed to mark investigation as completed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
// updateInvestigationStatus updates the status of a pending investigation row
// via the Supabase REST API (PATCH on pending_investigations). Lifecycle
// timestamps are stamped according to the target status; results are only
// stored for "completed", and errorMsg only for "failed".
func (s *InvestigationServer) updateInvestigationStatus(id, status string, results map[string]interface{}, errorMsg *string) error {
	updateData := map[string]interface{}{
		"status": status,
	}

	// Attach per-status timestamps and payloads.
	if status == "executing" {
		updateData["started_at"] = time.Now().UTC().Format(time.RFC3339)
	} else if status == "completed" {
		updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
		if results != nil {
			updateData["command_results"] = results
		}
	} else if status == "failed" && errorMsg != nil {
		updateData["error_message"] = *errorMsg
		updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
	}

	jsonData, err := json.Marshal(updateData)
	if err != nil {
		return fmt.Errorf("failed to marshal update data: %v", err)
	}

	// PostgREST filter: update only the row with this id.
	url := fmt.Sprintf("%s/rest/v1/pending_investigations?id=eq.%s", s.supabaseURL, id)
	req, err := http.NewRequest("PATCH", url, strings.NewReader(string(jsonData)))
	if err != nil {
		return fmt.Errorf("failed to create request: %v", err)
	}

	// Get token from auth manager
	authToken, err := s.authManager.LoadToken()
	if err != nil {
		return fmt.Errorf("failed to load auth token: %v", err)
	}

	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken.AccessToken))
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to update investigation: %v", err)
	}
	defer resp.Body.Close()

	// PostgREST replies 204 (or 200 with a representation) on success.
	if resp.StatusCode != 200 && resp.StatusCode != 204 {
		return fmt.Errorf("supabase update error: %d", resp.StatusCode)
	}

	return nil
}
|
||||
165
main.go
165
main.go
@@ -9,17 +9,25 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"nannyagentv2/internal/auth"
|
||||
"nannyagentv2/internal/config"
|
||||
"nannyagentv2/internal/metrics"
|
||||
"nannyagentv2/internal/types"
|
||||
)
|
||||
|
||||
const Version = "v2.0.0"
|
||||
|
||||
// checkRootPrivileges ensures the program is running as root; otherwise it
// prints an explanation to stderr and exits with status 1. Root is required
// because the agent loads eBPF programs into the kernel.
// Fix: the three reason bullet lines were printed twice (a duplicated Fprintf
// block, apparently left over from a merge); each is now printed once.
func checkRootPrivileges() {
	if os.Geteuid() != 0 {
		fmt.Fprintf(os.Stderr, "❌ ERROR: This program must be run as root for eBPF functionality.\n")
		fmt.Fprintf(os.Stderr, "Please run with: sudo %s\n", os.Args[0])
		fmt.Fprintf(os.Stderr, "Reason: eBPF programs require root privileges to:\n")
		fmt.Fprintf(os.Stderr, " - Load programs into the kernel\n")
		fmt.Fprintf(os.Stderr, " - Attach to kernel functions and tracepoints\n")
		fmt.Fprintf(os.Stderr, " - Access kernel memory maps\n")
		os.Exit(1)
	}
}
|
||||
@@ -59,20 +67,20 @@ func checkKernelVersionCompatibility() {
|
||||
fmt.Fprintf(os.Stderr, "Required: Linux kernel 4.4 or higher\n")
|
||||
fmt.Fprintf(os.Stderr, "Current: %s\n", kernelVersion)
|
||||
fmt.Fprintf(os.Stderr, "Reason: eBPF requires kernel features introduced in 4.4+:\n")
|
||||
fmt.Fprintf(os.Stderr, " - BPF system call support\n")
|
||||
fmt.Fprintf(os.Stderr, " - eBPF program types (kprobe, tracepoint)\n")
|
||||
fmt.Fprintf(os.Stderr, " - BPF maps and helper functions\n")
|
||||
fmt.Fprintf(os.Stderr, " - BPF system call support\n")
|
||||
fmt.Fprintf(os.Stderr, " - eBPF program types (kprobe, tracepoint)\n")
|
||||
fmt.Fprintf(os.Stderr, " - BPF maps and helper functions\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("✅ Kernel version %s is compatible with eBPF\n", kernelVersion)
|
||||
|
||||
}
|
||||
|
||||
// checkEBPFSupport validates eBPF subsystem availability
|
||||
func checkEBPFSupport() {
|
||||
// Check if /sys/kernel/debug/tracing exists (debugfs mounted)
|
||||
if _, err := os.Stat("/sys/kernel/debug/tracing"); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "⚠️ WARNING: debugfs not mounted. Some eBPF features may not work.\n")
|
||||
fmt.Fprintf(os.Stderr, "⚠️ WARNING: debugfs not mounted. Some eBPF features may not work.\n")
|
||||
fmt.Fprintf(os.Stderr, "To fix: sudo mount -t debugfs debugfs /sys/kernel/debug\n")
|
||||
}
|
||||
|
||||
@@ -81,35 +89,22 @@ func checkEBPFSupport() {
|
||||
if errno != 0 && errno != syscall.EINVAL {
|
||||
fmt.Fprintf(os.Stderr, "❌ ERROR: BPF syscall not available (errno: %v)\n", errno)
|
||||
fmt.Fprintf(os.Stderr, "This may indicate:\n")
|
||||
fmt.Fprintf(os.Stderr, " - Kernel compiled without BPF support\n")
|
||||
fmt.Fprintf(os.Stderr, " - BPF syscall disabled in kernel config\n")
|
||||
fmt.Fprintf(os.Stderr, " - Kernel compiled without BPF support\n")
|
||||
fmt.Fprintf(os.Stderr, " - BPF syscall disabled in kernel config\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
if fd > 0 {
|
||||
syscall.Close(int(fd))
|
||||
}
|
||||
|
||||
fmt.Printf("✅ eBPF syscall is available\n")
|
||||
|
||||
}
|
||||
|
||||
func main() {
|
||||
// runInteractiveDiagnostics starts the interactive diagnostic session
|
||||
func runInteractiveDiagnostics(agent *LinuxDiagnosticAgent) {
|
||||
fmt.Println("")
|
||||
fmt.Println("🔍 Linux eBPF-Enhanced Diagnostic Agent")
|
||||
fmt.Println("=======================================")
|
||||
|
||||
// Perform system compatibility checks
|
||||
fmt.Println("Performing system compatibility checks...")
|
||||
|
||||
checkRootPrivileges()
|
||||
checkKernelVersionCompatibility()
|
||||
checkEBPFSupport()
|
||||
|
||||
fmt.Println("✅ All system checks passed")
|
||||
fmt.Println("")
|
||||
|
||||
// Initialize the agent
|
||||
agent := NewLinuxDiagnosticAgent()
|
||||
|
||||
// Start the interactive session
|
||||
fmt.Println("Linux Diagnostic Agent Started")
|
||||
fmt.Println("Enter a system issue description (or 'quit' to exit):")
|
||||
|
||||
@@ -129,8 +124,8 @@ func main() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Process the issue with eBPF capabilities
|
||||
if err := agent.DiagnoseWithEBPF(input); err != nil {
|
||||
// Process the issue with AI capabilities via TensorZero
|
||||
if err := agent.DiagnoseIssue(input); err != nil {
|
||||
fmt.Printf("Error: %v\n", err)
|
||||
}
|
||||
}
|
||||
@@ -141,3 +136,115 @@ func main() {
|
||||
|
||||
fmt.Println("Goodbye!")
|
||||
}
|
||||
|
||||
// main is the agent entrypoint: it validates the host environment, loads
// configuration, authenticates, launches the WebSocket client and a
// background heartbeat loop, then blocks in the interactive CLI session.
func main() {
	fmt.Printf("🚀 NannyAgent v%s starting...\n", Version)

	// Perform system compatibility checks first; each helper is expected
	// to terminate the process itself on an unsupported system.
	fmt.Println("Performing system compatibility checks...")
	checkRootPrivileges()
	checkKernelVersionCompatibility()
	checkEBPFSupport()
	fmt.Println("✅ All system checks passed")
	fmt.Println("")

	// Load configuration
	cfg, err := config.LoadConfig()
	if err != nil {
		log.Fatalf("❌ Failed to load configuration: %v", err)
	}

	cfg.PrintConfig()

	// Initialize components
	authManager := auth.NewAuthManager(cfg)
	metricsCollector := metrics.NewCollector(Version)

	// Ensure authentication before starting any network activity.
	token, err := authManager.EnsureAuthenticated()
	if err != nil {
		log.Fatalf("❌ Authentication failed: %v", err)
	}

	fmt.Println("✅ Authentication successful!")

	// Initialize the diagnostic agent for interactive CLI use
	agent := NewLinuxDiagnosticAgent()

	// Initialize a separate agent for WebSocket investigations using the
	// application-specific TensorZero function (distinct model name).
	applicationAgent := NewLinuxDiagnosticAgent()
	applicationAgent.model = "tensorzero::function_name::diagnose_and_heal_application"

	// Start WebSocket client for backend communications and investigations
	wsClient := NewWebSocketClient(applicationAgent, authManager)
	go func() {
		if err := wsClient.Start(); err != nil {
			log.Printf("❌ WebSocket client error: %v", err)
		}
	}()

	// Start background metrics collection in a goroutine.
	// NOTE(review): `token` is only used inside this goroutine after
	// startup, so there is no sharing hazard with main — confirm if that
	// changes.
	go func() {
		fmt.Println("❤️ Starting background metrics collection and heartbeat...")

		ticker := time.NewTicker(time.Duration(cfg.MetricsInterval) * time.Second)
		defer ticker.Stop()

		// Send initial heartbeat
		if err := sendHeartbeat(cfg, token, metricsCollector); err != nil {
			log.Printf("⚠️ Initial heartbeat failed: %v", err)
		}

		// Main heartbeat loop
		for range ticker.C {
			// Proactively refresh the token before it expires.
			if authManager.IsTokenExpired(token) {
				fmt.Println("🔄 Token expiring soon, refreshing...")
				newToken, refreshErr := authManager.EnsureAuthenticated()
				if refreshErr != nil {
					log.Printf("❌ Token refresh failed: %v", refreshErr)
					continue
				}
				token = newToken
				fmt.Println("✅ Token refreshed successfully")
			}

			// Send heartbeat
			if err := sendHeartbeat(cfg, token, metricsCollector); err != nil {
				log.Printf("⚠️ Heartbeat failed: %v", err)

				// If unauthorized, try to refresh token and retry once.
				// NOTE(review): this relies on the exact error string
				// "unauthorized" from sendHeartbeat's call chain — confirm.
				if err.Error() == "unauthorized" {
					fmt.Println("🔄 Unauthorized, attempting token refresh...")
					newToken, refreshErr := authManager.EnsureAuthenticated()
					if refreshErr != nil {
						log.Printf("❌ Token refresh failed: %v", refreshErr)
						continue
					}
					token = newToken

					// Retry heartbeat with new token (silently)
					if retryErr := sendHeartbeat(cfg, token, metricsCollector); retryErr != nil {
						log.Printf("⚠️ Retry heartbeat failed: %v", retryErr)
					}
				}
			}
			// No logging for successful heartbeats - they should be silent
		}
	}()

	// Start the interactive diagnostic session (blocking)
	runInteractiveDiagnostics(agent)
}
|
||||
|
||||
// sendHeartbeat gathers current system metrics via the collector and posts
// them to the backend heartbeat endpoint, authenticated with the given token.
// The agent identity sent upstream comes from the token, not from env vars.
func sendHeartbeat(cfg *config.Config, token *types.AuthToken, collector *metrics.Collector) error {
	// Collect system metrics
	systemMetrics, err := collector.GatherSystemMetrics()
	if err != nil {
		return fmt.Errorf("failed to gather system metrics: %w", err)
	}

	// Send metrics using the collector with correct agent_id from token
	return collector.SendMetrics(cfg.AgentAuthURL, token.AccessToken, token.AgentID, systemMetrics)
}
|
||||
|
||||
839
websocket_client.go
Normal file
839
websocket_client.go
Normal file
@@ -0,0 +1,839 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nannyagentv2/internal/auth"
|
||||
"nannyagentv2/internal/metrics"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// Helper function for minimum of two integers
|
||||
|
||||
// WebSocketMessage is the envelope for every frame exchanged with the
// backend; Type selects the handler and Data carries the untyped payload.
type WebSocketMessage struct {
	Type string      `json:"type"`
	Data interface{} `json:"data"`
}

// InvestigationTask is a diagnostic job pushed to the agent over the
// WebSocket. DiagnosticPayload holds the "commands" / "ebpf_programs" lists.
type InvestigationTask struct {
	TaskID            string                 `json:"task_id"`
	InvestigationID   string                 `json:"investigation_id"`
	AgentID           string                 `json:"agent_id"`
	DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
	EpisodeID         string                 `json:"episode_id,omitempty"`
}

// TaskResult reports the outcome of an InvestigationTask back to the backend.
type TaskResult struct {
	TaskID         string                 `json:"task_id"`
	Success        bool                   `json:"success"`
	CommandResults map[string]interface{} `json:"command_results,omitempty"`
	Error          string                 `json:"error,omitempty"`
}

// HeartbeatData is the payload of periodic "heartbeat" frames.
type HeartbeatData struct {
	AgentID   string    `json:"agent_id"`
	Timestamp time.Time `json:"timestamp"`
	Version   string    `json:"version"`
}

// WebSocketClient maintains the long-lived connection to the Supabase
// backend, dispatches investigation tasks, sends heartbeats, and polls the
// database proxy for pending investigations.
type WebSocketClient struct {
	agent            *LinuxDiagnosticAgent // agent used for AI continuations and eBPF execution
	conn             *websocket.Conn       // live connection; nil until connect() succeeds
	agentID          string
	authManager      *auth.AuthManager
	metricsCollector *metrics.Collector
	supabaseURL      string // project base URL, e.g. https://<id>.supabase.co
	token            string // current JWT access token
	ctx              context.Context
	cancel           context.CancelFunc
	// Track consecutive connection failures; logging is suppressed until
	// this reaches 5.
	// NOTE(review): read/written from several goroutines (reader, heartbeat,
	// reconnection) without synchronization — looks racy; confirm and
	// consider an atomic counter.
	consecutiveFailures int
}
|
||||
|
||||
// NewWebSocketClient creates a new WebSocket client bound to the given
// diagnostic agent and auth manager. The agent ID is resolved from the auth
// manager, falling back to the AGENT_ID env var, then to a timestamp-based
// placeholder. SUPABASE_PROJECT_URL must be set.
func NewWebSocketClient(agent *LinuxDiagnosticAgent, authManager *auth.AuthManager) *WebSocketClient {
	// Get agent ID from authentication system
	var agentID string
	if authManager != nil {
		if id, err := authManager.GetCurrentAgentID(); err == nil {
			agentID = id
			// Agent ID retrieved successfully
		} else {
			fmt.Printf("❌ Failed to get agent ID from auth manager: %v\n", err)
		}
	}

	// Fallback to environment variable or generate one if auth fails
	if agentID == "" {
		agentID = os.Getenv("AGENT_ID")
		if agentID == "" {
			agentID = fmt.Sprintf("agent-%d", time.Now().Unix())
		}
	}

	// NOTE(review): log.Fatal inside a constructor kills the whole process;
	// consider returning an error instead — confirm callers can handle that.
	supabaseURL := os.Getenv("SUPABASE_PROJECT_URL")
	if supabaseURL == "" {
		log.Fatal("❌ SUPABASE_PROJECT_URL environment variable is required")
	}

	// Create metrics collector
	metricsCollector := metrics.NewCollector("v2.0.0")

	ctx, cancel := context.WithCancel(context.Background())

	return &WebSocketClient{
		agent:            agent,
		agentID:          agentID,
		authManager:      authManager,
		metricsCollector: metricsCollector,
		supabaseURL:      supabaseURL,
		ctx:              ctx,
		cancel:           cancel,
	}
}
|
||||
|
||||
// Start starts the WebSocket connection and message handling
|
||||
func (w *WebSocketClient) Start() error {
|
||||
// Starting WebSocket client
|
||||
|
||||
if err := w.connect(); err != nil {
|
||||
return fmt.Errorf("failed to establish WebSocket connection: %v", err)
|
||||
}
|
||||
|
||||
// Start message reading loop
|
||||
go w.handleMessages()
|
||||
|
||||
// Start heartbeat
|
||||
go w.startHeartbeat()
|
||||
|
||||
// Start database polling for pending investigations
|
||||
go w.pollPendingInvestigations()
|
||||
|
||||
// WebSocket client started
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop cancels the client context (terminating all background loops) and
// closes the WebSocket connection if one was ever established.
func (c *WebSocketClient) Stop() {
	c.cancel()
	if c.conn != nil {
		c.conn.Close()
	}
}
|
||||
|
||||
// getAuthToken retrieves authentication token
|
||||
func (c *WebSocketClient) getAuthToken() error {
|
||||
if c.authManager == nil {
|
||||
return fmt.Errorf("auth manager not available")
|
||||
}
|
||||
|
||||
token, err := c.authManager.EnsureAuthenticated()
|
||||
if err != nil {
|
||||
return fmt.Errorf("authentication failed: %v", err)
|
||||
}
|
||||
|
||||
c.token = token.AccessToken
|
||||
return nil
|
||||
}
|
||||
|
||||
// connect establishes WebSocket connection
|
||||
func (c *WebSocketClient) connect() error {
|
||||
// Get fresh auth token
|
||||
if err := c.getAuthToken(); err != nil {
|
||||
return fmt.Errorf("failed to get auth token: %v", err)
|
||||
}
|
||||
|
||||
// Convert HTTP URL to WebSocket URL
|
||||
wsURL := strings.Replace(c.supabaseURL, "https://", "wss://", 1)
|
||||
wsURL = strings.Replace(wsURL, "http://", "ws://", 1)
|
||||
wsURL += "/functions/v1/websocket-agent-handler"
|
||||
|
||||
// Connecting to WebSocket
|
||||
|
||||
// Set up headers
|
||||
headers := http.Header{}
|
||||
headers.Set("Authorization", "Bearer "+c.token)
|
||||
|
||||
// Connect
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
conn, resp, err := dialer.Dial(wsURL, headers)
|
||||
if err != nil {
|
||||
c.consecutiveFailures++
|
||||
if c.consecutiveFailures >= 5 && resp != nil {
|
||||
fmt.Printf("❌ WebSocket handshake failed with status: %d (failure #%d)\n", resp.StatusCode, c.consecutiveFailures)
|
||||
}
|
||||
return fmt.Errorf("websocket connection failed: %v", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
// WebSocket client connected
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMessages is the connection's read loop: it decodes frames, resets
// the failure counter on success, and dispatches by message type. On any
// read error it hands off to attemptReconnection (in a new goroutine) and
// exits; the deferred Close tears down the current connection.
//
// NOTE(review): c.conn and c.consecutiveFailures are also touched by the
// heartbeat and reconnection goroutines without synchronization — confirm
// whether a mutex/atomic is needed.
func (c *WebSocketClient) handleMessages() {
	defer func() {
		if c.conn != nil {
			c.conn.Close()
		}
	}()

	connectionStart := time.Now()

	for {
		select {
		case <-c.ctx.Done():
			// Only log context cancellation if there have been failures
			if c.consecutiveFailures >= 5 {
				fmt.Printf("📡 Context cancelled after %v, stopping message handler\n", time.Since(connectionStart))
			}
			return
		default:
			// Set read deadline to detect connection issues; a silent
			// connection for 90s surfaces as a read timeout below.
			c.conn.SetReadDeadline(time.Now().Add(90 * time.Second))

			var message WebSocketMessage
			readStart := time.Now()
			err := c.conn.ReadJSON(&message)
			readDuration := time.Since(readStart)

			if err != nil {
				connectionDuration := time.Since(connectionStart)

				// Only log specific errors after failure threshold
				if c.consecutiveFailures >= 5 {
					if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
						log.Printf("🔒 WebSocket closed normally after %v: %v", connectionDuration, err)
					} else if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
						log.Printf("💥 ABNORMAL CLOSE after %v (code 1006 = server-side timeout/kill): %v", connectionDuration, err)
						log.Printf("🕒 Last read took %v, connection lived %v", readDuration, connectionDuration)
					} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
						log.Printf("⏰ READ TIMEOUT after %v: %v", connectionDuration, err)
					} else {
						log.Printf("❌ WebSocket error after %v: %v", connectionDuration, err)
					}
				}

				// Track consecutive failures for diagnostic threshold
				c.consecutiveFailures++

				// Only show diagnostics after multiple failures
				if c.consecutiveFailures >= 5 {
					log.Printf("🔍 DIAGNOSTIC - Connection failed #%d after %v", c.consecutiveFailures, connectionDuration)
				}

				// Attempt reconnection instead of returning immediately
				go c.attemptReconnection()
				return
			}

			// Received a message successfully - reset failure counter
			c.consecutiveFailures = 0

			switch message.Type {
			case "connection_ack":
				// Connection acknowledged

			case "heartbeat_ack":
				// Heartbeat acknowledged

			case "investigation_task":
				// Execute each task on its own goroutine so slow commands
				// don't stall the read loop.
				go c.handleInvestigationTask(message.Data)

			case "task_result_ack":
				// Task result acknowledged

			default:
				log.Printf("⚠️ Unknown message type: %s", message.Type)
			}
		}
	}
}
|
||||
|
||||
// handleInvestigationTask processes investigation tasks from the backend
|
||||
func (c *WebSocketClient) handleInvestigationTask(data interface{}) {
|
||||
// Parse task data
|
||||
taskBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
log.Printf("❌ Error marshaling task data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var task InvestigationTask
|
||||
err = json.Unmarshal(taskBytes, &task)
|
||||
if err != nil {
|
||||
log.Printf("❌ Error unmarshaling investigation task: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Processing investigation task
|
||||
|
||||
// Execute diagnostic commands
|
||||
results, err := c.executeDiagnosticCommands(task.DiagnosticPayload)
|
||||
|
||||
// Prepare task result
|
||||
taskResult := TaskResult{
|
||||
TaskID: task.TaskID,
|
||||
Success: err == nil,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
taskResult.Error = err.Error()
|
||||
fmt.Printf("❌ Task execution failed: %v\n", err)
|
||||
} else {
|
||||
taskResult.CommandResults = results
|
||||
// Task executed successfully
|
||||
}
|
||||
|
||||
// Send result back
|
||||
c.sendTaskResult(taskResult)
|
||||
}
|
||||
|
||||
// executeDiagnosticCommands executes the commands from a diagnostic response
|
||||
func (c *WebSocketClient) executeDiagnosticCommands(diagnosticPayload map[string]interface{}) (map[string]interface{}, error) {
|
||||
results := map[string]interface{}{
|
||||
"agent_id": c.agentID,
|
||||
"execution_time": time.Now().UTC().Format(time.RFC3339),
|
||||
"command_results": []map[string]interface{}{},
|
||||
}
|
||||
|
||||
// Extract commands from diagnostic payload
|
||||
commands, ok := diagnosticPayload["commands"].([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no commands found in diagnostic payload")
|
||||
}
|
||||
|
||||
var commandResults []map[string]interface{}
|
||||
|
||||
for _, cmd := range commands {
|
||||
cmdMap, ok := cmd.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
id, _ := cmdMap["id"].(string)
|
||||
command, _ := cmdMap["command"].(string)
|
||||
description, _ := cmdMap["description"].(string)
|
||||
|
||||
if command == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Executing command
|
||||
|
||||
// Execute the command
|
||||
output, exitCode, err := c.executeCommand(command)
|
||||
|
||||
result := map[string]interface{}{
|
||||
"id": id,
|
||||
"command": command,
|
||||
"description": description,
|
||||
"output": output,
|
||||
"exit_code": exitCode,
|
||||
"success": err == nil && exitCode == 0,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
result["error"] = err.Error()
|
||||
fmt.Printf("❌ Command [%s] failed: %v (exit code: %d)\n", id, err, exitCode)
|
||||
}
|
||||
|
||||
commandResults = append(commandResults, result)
|
||||
}
|
||||
|
||||
results["command_results"] = commandResults
|
||||
results["total_commands"] = len(commandResults)
|
||||
results["successful_commands"] = c.countSuccessfulCommands(commandResults)
|
||||
|
||||
// Execute eBPF programs if present
|
||||
ebpfPrograms, hasEBPF := diagnosticPayload["ebpf_programs"].([]interface{})
|
||||
if hasEBPF && len(ebpfPrograms) > 0 {
|
||||
ebpfResults := c.executeEBPFPrograms(ebpfPrograms)
|
||||
results["ebpf_results"] = ebpfResults
|
||||
results["total_ebpf_programs"] = len(ebpfPrograms)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// executeEBPFPrograms executes eBPF monitoring programs using the real eBPF manager
|
||||
func (c *WebSocketClient) executeEBPFPrograms(ebpfPrograms []interface{}) []map[string]interface{} {
|
||||
var ebpfRequests []EBPFRequest
|
||||
|
||||
// Convert interface{} to EBPFRequest structs
|
||||
for _, prog := range ebpfPrograms {
|
||||
progMap, ok := prog.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
name, _ := progMap["name"].(string)
|
||||
progType, _ := progMap["type"].(string)
|
||||
target, _ := progMap["target"].(string)
|
||||
duration, _ := progMap["duration"].(float64)
|
||||
description, _ := progMap["description"].(string)
|
||||
|
||||
if name == "" || progType == "" || target == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
ebpfRequests = append(ebpfRequests, EBPFRequest{
|
||||
Name: name,
|
||||
Type: progType,
|
||||
Target: target,
|
||||
Duration: int(duration),
|
||||
Description: description,
|
||||
})
|
||||
}
|
||||
|
||||
// Execute eBPF programs using the agent's eBPF execution logic
|
||||
return c.agent.executeEBPFPrograms(ebpfRequests)
|
||||
}
|
||||
|
||||
// executeCommandsFromPayload executes commands from a payload and returns results
|
||||
func (c *WebSocketClient) executeCommandsFromPayload(commands []interface{}) []map[string]interface{} {
|
||||
var commandResults []map[string]interface{}
|
||||
|
||||
for _, cmd := range commands {
|
||||
cmdMap, ok := cmd.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
id, _ := cmdMap["id"].(string)
|
||||
command, _ := cmdMap["command"].(string)
|
||||
description, _ := cmdMap["description"].(string)
|
||||
|
||||
if command == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Execute the command
|
||||
output, exitCode, err := c.executeCommand(command)
|
||||
|
||||
result := map[string]interface{}{
|
||||
"id": id,
|
||||
"command": command,
|
||||
"description": description,
|
||||
"output": output,
|
||||
"exit_code": exitCode,
|
||||
"success": err == nil && exitCode == 0,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
result["error"] = err.Error()
|
||||
fmt.Printf("❌ Command [%s] failed: %v (exit code: %d)\n", id, err, exitCode)
|
||||
}
|
||||
|
||||
commandResults = append(commandResults, result)
|
||||
}
|
||||
|
||||
return commandResults
|
||||
}
|
||||
|
||||
// executeCommand runs a single command with a 30-second timeout and returns
// its combined stdout+stderr, exit code, and the raw exec error (if any).
// The command is split on whitespace only — quoted arguments, pipes, and
// other shell operators are NOT interpreted.
// NOTE(review): confirm the backend never sends shell syntax; if it does,
// commands will run with literal "|"/quotes as arguments.
// Exit code is -1 when the process could not be started or was killed
// (including context-timeout kills).
func (c *WebSocketClient) executeCommand(command string) (string, int, error) {
	// Parse command into parts
	parts := strings.Fields(command)
	if len(parts) == 0 {
		return "", -1, fmt.Errorf("empty command")
	}

	// Create command with timeout
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, parts[0], parts[1:]...)
	cmd.Env = os.Environ()

	output, err := cmd.CombinedOutput()
	exitCode := 0

	if err != nil {
		if exitError, ok := err.(*exec.ExitError); ok {
			exitCode = exitError.ExitCode()
		} else {
			// Start failure, timeout kill, etc. — no meaningful exit code.
			exitCode = -1
		}
	}

	return string(output), exitCode, err
}
|
||||
|
||||
// countSuccessfulCommands counts the number of successful commands
|
||||
func (c *WebSocketClient) countSuccessfulCommands(results []map[string]interface{}) int {
|
||||
count := 0
|
||||
for _, result := range results {
|
||||
if success, ok := result["success"].(bool); ok && success {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// sendTaskResult pushes a "task_result" frame to the backend; write errors
// are logged and otherwise dropped.
// NOTE(review): this can run concurrently with the heartbeat writer on the
// same connection, and gorilla/websocket allows only one concurrent writer —
// confirm whether a write mutex is needed. c.conn may also be nil if the
// connection was never established.
func (c *WebSocketClient) sendTaskResult(result TaskResult) {
	message := WebSocketMessage{
		Type: "task_result",
		Data: result,
	}

	err := c.conn.WriteJSON(message)
	if err != nil {
		log.Printf("❌ Error sending task result: %v", err)
	}
}
|
||||
|
||||
// startHeartbeat sends a "heartbeat" frame every 30 seconds until the
// context is cancelled or a write fails (taken as a dead connection).
// NOTE(review): writes here can race with sendTaskResult on the same
// connection — see the note there.
func (c *WebSocketClient) startHeartbeat() {
	ticker := time.NewTicker(30 * time.Second) // Heartbeat every 30 seconds
	defer ticker.Stop()

	for {
		select {
		case <-c.ctx.Done():
			fmt.Printf("💓 Heartbeat stopped due to context cancellation\n")
			return
		case <-ticker.C:
			heartbeat := WebSocketMessage{
				Type: "heartbeat",
				Data: HeartbeatData{
					AgentID:   c.agentID,
					Timestamp: time.Now(),
					Version:   "v2.0.0",
				},
			}

			err := c.conn.WriteJSON(heartbeat)
			if err != nil {
				log.Printf("❌ Error sending heartbeat: %v", err)
				fmt.Printf("💓 Heartbeat failed, connection likely dead\n")
				return
			}
			// Heartbeat sent — successful sends are silent.
		}
	}
}
|
||||
|
||||
// pollPendingInvestigations checks the backend every 5 seconds for pending
// investigations until the client context is cancelled. This is the
// fallback delivery path alongside pushed "investigation_task" frames.
func (c *WebSocketClient) pollPendingInvestigations() {
	ticker := time.NewTicker(5 * time.Second) // Poll every 5 seconds
	defer ticker.Stop()

	for {
		select {
		case <-c.ctx.Done():
			return
		case <-ticker.C:
			c.checkForPendingInvestigations()
		}
	}
}
|
||||
|
||||
// checkForPendingInvestigations checks the database for new pending investigations via proxy
|
||||
func (c *WebSocketClient) checkForPendingInvestigations() {
|
||||
// Use Edge Function proxy instead of direct database access
|
||||
url := fmt.Sprintf("%s/functions/v1/agent-database-proxy/pending-investigations", c.supabaseURL)
|
||||
|
||||
// Poll database for pending investigations
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
// Request creation failed
|
||||
return
|
||||
}
|
||||
|
||||
// Only JWT token needed for proxy - no API keys exposed
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// Database request failed
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return
|
||||
}
|
||||
|
||||
var investigations []PendingInvestigation
|
||||
err = json.NewDecoder(resp.Body).Decode(&investigations)
|
||||
if err != nil {
|
||||
// Response decode failed
|
||||
return
|
||||
}
|
||||
|
||||
for _, investigation := range investigations {
|
||||
go c.handlePendingInvestigation(investigation)
|
||||
}
|
||||
}
|
||||
|
||||
// handlePendingInvestigation executes one polled investigation end-to-end:
// mark it "executing", run its diagnostic commands, then iterate a
// TensorZero conversation (command results in, AI response out) until the
// AI returns a "resolution" (or an unrecognized) response, and finally
// store the results as "completed_with_analysis". On command failure the
// row is marked "failed"; on AI failure it falls back to "completed" with
// command results only.
// NOTE(review): the conversation loop has no iteration cap — a model that
// keeps answering "diagnostic" loops indefinitely; confirm whether a max
// round count is wanted. Errors from updateInvestigationStatus are
// deliberately ignored on the fallback paths.
func (c *WebSocketClient) handlePendingInvestigation(investigation PendingInvestigation) {
	// Mark as executing
	err := c.updateInvestigationStatus(investigation.ID, "executing", nil, nil)
	if err != nil {
		return
	}

	// Execute diagnostic commands
	results, err := c.executeDiagnosticCommands(investigation.DiagnosticPayload)

	// Prepare the base results map we'll send to DB
	resultsForDB := map[string]interface{}{
		"agent_id":        c.agentID,
		"execution_time":  time.Now().UTC().Format(time.RFC3339),
		"command_results": results,
	}

	// If command execution failed, mark investigation as failed
	if err != nil {
		errorMsg := err.Error()
		// Include partial results when possible
		if results != nil {
			resultsForDB["command_results"] = results
		}
		c.updateInvestigationStatus(investigation.ID, "failed", resultsForDB, &errorMsg)
		return
	}

	// Try to continue the TensorZero conversation by sending command results back.
	// Build messages: assistant = diagnostic payload, user = command results.
	diagJSON, _ := json.Marshal(investigation.DiagnosticPayload)
	commandsJSON, _ := json.MarshalIndent(results, "", " ")

	messages := []openai.ChatCompletionMessage{
		{
			Role:    openai.ChatMessageRoleAssistant,
			Content: string(diagJSON),
		},
		{
			Role:    openai.ChatMessageRoleUser,
			Content: string(commandsJSON),
		},
	}

	// Use the episode ID from the investigation to maintain conversation continuity
	episodeID := ""
	if investigation.EpisodeID != nil {
		episodeID = *investigation.EpisodeID
	}

	// Continue conversation until resolution (same flow as the CLI agent)
	var finalAIContent string
	for {
		tzResp, tzErr := c.agent.sendRequestWithEpisode(messages, episodeID)
		if tzErr != nil {
			fmt.Printf("⚠️ TensorZero continuation failed: %v\n", tzErr)
			// Fall back to marking completed with command results only
			c.updateInvestigationStatus(investigation.ID, "completed", resultsForDB, nil)
			return
		}

		if len(tzResp.Choices) == 0 {
			fmt.Printf("⚠️ No choices in TensorZero response\n")
			c.updateInvestigationStatus(investigation.ID, "completed", resultsForDB, nil)
			return
		}

		aiContent := tzResp.Choices[0].Message.Content
		// Only echo short responses to the console.
		if len(aiContent) > 300 {
			// AI response received successfully
		} else {
			fmt.Printf("🤖 AI Response: %s\n", aiContent)
		}

		// Check if this is a resolution response (final)
		var resolutionResp struct {
			ResponseType   string `json:"response_type"`
			RootCause      string `json:"root_cause"`
			ResolutionPlan string `json:"resolution_plan"`
			Confidence     string `json:"confidence"`
		}

		fmt.Printf("🔍 Analyzing AI response type...\n")

		if err := json.Unmarshal([]byte(aiContent), &resolutionResp); err == nil && resolutionResp.ResponseType == "resolution" {
			// This is the final resolution - show summary and complete
			fmt.Printf("\n=== DIAGNOSIS COMPLETE ===\n")
			fmt.Printf("Root Cause: %s\n", resolutionResp.RootCause)
			fmt.Printf("Resolution Plan: %s\n", resolutionResp.ResolutionPlan)
			fmt.Printf("Confidence: %s\n", resolutionResp.Confidence)
			finalAIContent = aiContent
			break
		}

		// Check if this is another diagnostic response requiring more commands
		var diagnosticResp struct {
			ResponseType string        `json:"response_type"`
			Commands     []interface{} `json:"commands"`
			EBPFPrograms []interface{} `json:"ebpf_programs"`
		}

		if err := json.Unmarshal([]byte(aiContent), &diagnosticResp); err == nil && diagnosticResp.ResponseType == "diagnostic" {
			fmt.Printf("🔄 AI requested additional diagnostics, executing...\n")

			// Execute additional commands if any
			additionalResults := map[string]interface{}{
				"command_results": []map[string]interface{}{},
			}

			if len(diagnosticResp.Commands) > 0 {
				fmt.Printf("🔧 Executing %d additional diagnostic commands...\n", len(diagnosticResp.Commands))
				commandResults := c.executeCommandsFromPayload(diagnosticResp.Commands)
				additionalResults["command_results"] = commandResults
			}

			// Execute additional eBPF programs if any
			if len(diagnosticResp.EBPFPrograms) > 0 {
				ebpfResults := c.executeEBPFPrograms(diagnosticResp.EBPFPrograms)
				additionalResults["ebpf_results"] = ebpfResults
			}

			// Add AI response and additional results to conversation
			messages = append(messages, openai.ChatCompletionMessage{
				Role:    openai.ChatMessageRoleAssistant,
				Content: aiContent,
			})

			additionalResultsJSON, _ := json.MarshalIndent(additionalResults, "", " ")
			messages = append(messages, openai.ChatCompletionMessage{
				Role:    openai.ChatMessageRoleUser,
				Content: string(additionalResultsJSON),
			})

			continue
		}

		// If neither resolution nor diagnostic, treat as final response
		fmt.Printf("⚠️ Unknown response type - treating as final response\n")
		finalAIContent = aiContent
		break
	}

	// Attach final AI response to results for DB and mark as completed_with_analysis
	resultsForDB["ai_response"] = finalAIContent
	c.updateInvestigationStatus(investigation.ID, "completed_with_analysis", resultsForDB, nil)
}
|
||||
|
||||
// updateInvestigationStatus updates the status of a pending investigation
|
||||
func (c *WebSocketClient) updateInvestigationStatus(id, status string, results map[string]interface{}, errorMsg *string) error {
|
||||
updateData := map[string]interface{}{
|
||||
"status": status,
|
||||
}
|
||||
|
||||
if status == "executing" {
|
||||
updateData["started_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
} else if status == "completed" {
|
||||
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
if results != nil {
|
||||
updateData["command_results"] = results
|
||||
}
|
||||
} else if status == "failed" && errorMsg != nil {
|
||||
updateData["error_message"] = *errorMsg
|
||||
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(updateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal update data: %v", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/functions/v1/agent-database-proxy/pending-investigations/%s", c.supabaseURL, id)
|
||||
req, err := http.NewRequest("PATCH", url, strings.NewReader(string(jsonData)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// Only JWT token needed for proxy - no API keys exposed
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update investigation: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 && resp.StatusCode != 204 {
|
||||
return fmt.Errorf("supabase update error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// attemptReconnection attempts to reconnect the WebSocket with backoff
|
||||
func (c *WebSocketClient) attemptReconnection() {
|
||||
backoffDurations := []time.Duration{
|
||||
2 * time.Second,
|
||||
5 * time.Second,
|
||||
10 * time.Second,
|
||||
20 * time.Second,
|
||||
30 * time.Second,
|
||||
}
|
||||
|
||||
for i, backoff := range backoffDurations {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
c.consecutiveFailures++
|
||||
|
||||
// Only show messages after 5 consecutive failures
|
||||
if c.consecutiveFailures >= 5 {
|
||||
log.Printf("🔄 Attempting WebSocket reconnection (attempt %d/%d) - %d consecutive failures", i+1, len(backoffDurations), c.consecutiveFailures)
|
||||
}
|
||||
|
||||
time.Sleep(backoff)
|
||||
|
||||
if err := c.connect(); err != nil {
|
||||
if c.consecutiveFailures >= 5 {
|
||||
log.Printf("❌ Reconnection attempt %d failed: %v", i+1, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Successfully reconnected - reset failure counter
|
||||
if c.consecutiveFailures >= 5 {
|
||||
log.Printf("✅ WebSocket reconnected successfully after %d failures", c.consecutiveFailures)
|
||||
}
|
||||
c.consecutiveFailures = 0
|
||||
go c.handleMessages() // Restart message handling
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("❌ Failed to reconnect after %d attempts, giving up", len(backoffDurations))
|
||||
}
|
||||
Reference in New Issue
Block a user