Somewhat okay refactoring
This commit is contained in:
343
internal/ebpf/ebpf_event_parser.go
Normal file
343
internal/ebpf/ebpf_event_parser.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package ebpf
|
||||
|
||||
import (
	"bufio"
	"io"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"time"
)
|
||||
|
||||
// EventScanner parses bpftrace output and converts it to TraceEvent structs
|
||||
type EventScanner struct {
|
||||
scanner *bufio.Scanner
|
||||
lastEvent *TraceEvent
|
||||
lineRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
// NewEventScanner creates a new event scanner for parsing bpftrace output
|
||||
func NewEventScanner(reader io.Reader) *EventScanner {
|
||||
// Regex pattern to match our trace output format:
|
||||
// TRACE|timestamp|pid|tid|comm|function|message
|
||||
pattern := `^TRACE\|(\d+)\|(\d+)\|(\d+)\|([^|]+)\|([^|]+)\|(.*)$`
|
||||
regex, _ := regexp.Compile(pattern)
|
||||
|
||||
return &EventScanner{
|
||||
scanner: bufio.NewScanner(reader),
|
||||
lineRegex: regex,
|
||||
}
|
||||
}
|
||||
|
||||
// Scan advances the scanner to the next event
|
||||
func (es *EventScanner) Scan() bool {
|
||||
for es.scanner.Scan() {
|
||||
line := strings.TrimSpace(es.scanner.Text())
|
||||
|
||||
// Skip empty lines and non-trace lines
|
||||
if line == "" || !strings.HasPrefix(line, "TRACE|") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse the trace line
|
||||
if event := es.parseLine(line); event != nil {
|
||||
es.lastEvent = event
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Event returns the most recently parsed event
|
||||
func (es *EventScanner) Event() *TraceEvent {
|
||||
return es.lastEvent
|
||||
}
|
||||
|
||||
// Error returns any scanning error
|
||||
func (es *EventScanner) Error() error {
|
||||
return es.scanner.Err()
|
||||
}
|
||||
|
||||
// parseLine parses a single trace line into a TraceEvent
|
||||
func (es *EventScanner) parseLine(line string) *TraceEvent {
|
||||
matches := es.lineRegex.FindStringSubmatch(line)
|
||||
if len(matches) != 7 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse timestamp (nanoseconds)
|
||||
timestamp, err := strconv.ParseInt(matches[1], 10, 64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse PID
|
||||
pid, err := strconv.Atoi(matches[2])
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse TID
|
||||
tid, err := strconv.Atoi(matches[3])
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract process name, function, and message
|
||||
processName := strings.TrimSpace(matches[4])
|
||||
function := strings.TrimSpace(matches[5])
|
||||
message := strings.TrimSpace(matches[6])
|
||||
|
||||
event := &TraceEvent{
|
||||
Timestamp: timestamp,
|
||||
PID: pid,
|
||||
TID: tid,
|
||||
ProcessName: processName,
|
||||
Function: function,
|
||||
Message: message,
|
||||
RawArgs: make(map[string]string),
|
||||
}
|
||||
|
||||
// Try to extract additional information from the message
|
||||
es.enrichEvent(event, message)
|
||||
|
||||
return event
|
||||
}
|
||||
|
||||
// enrichEvent extracts additional information from the message
|
||||
func (es *EventScanner) enrichEvent(event *TraceEvent, message string) {
|
||||
// Parse common patterns in messages to extract arguments
|
||||
// This is a simplified version - in a real implementation you'd want more sophisticated parsing
|
||||
|
||||
// Look for patterns like "arg1=value, arg2=value"
|
||||
argPattern := regexp.MustCompile(`(\w+)=([^,\s]+)`)
|
||||
matches := argPattern.FindAllStringSubmatch(message, -1)
|
||||
|
||||
for _, match := range matches {
|
||||
if len(match) == 3 {
|
||||
event.RawArgs[match[1]] = match[2]
|
||||
}
|
||||
}
|
||||
|
||||
// Look for numeric patterns that might be syscall arguments
|
||||
numberPattern := regexp.MustCompile(`\b(\d+)\b`)
|
||||
numbers := numberPattern.FindAllString(message, -1)
|
||||
|
||||
for i, num := range numbers {
|
||||
argName := "arg" + strconv.Itoa(i+1)
|
||||
event.RawArgs[argName] = num
|
||||
}
|
||||
}
|
||||
|
||||
// TraceEventFilter describes the criteria an event must meet to be kept.
// All populated criteria must match; a zero-valued criterion is ignored.
type TraceEventFilter struct {
	// MinTimestamp is the lowest accepted timestamp; ignored unless > 0.
	MinTimestamp int64
	// MaxTimestamp is the highest accepted timestamp; ignored unless > 0.
	MaxTimestamp int64
	// ProcessNames accepts events whose process name contains any entry.
	ProcessNames []string
	// PIDs accepts events whose PID equals any entry.
	PIDs []int
	// UIDs accepts events whose UID equals any entry.
	UIDs []int
	// Functions accepts events whose function name contains any entry.
	Functions []string
	// MessageFilter accepts events whose message contains this substring.
	MessageFilter string
}
|
||||
|
||||
// ApplyFilter applies filters to a slice of events
|
||||
func (filter *TraceEventFilter) ApplyFilter(events []TraceEvent) []TraceEvent {
|
||||
if filter == nil {
|
||||
return events
|
||||
}
|
||||
|
||||
var filtered []TraceEvent
|
||||
|
||||
for _, event := range events {
|
||||
if filter.matchesEvent(&event) {
|
||||
filtered = append(filtered, event)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
// matchesEvent checks if an event matches the filter criteria
|
||||
func (filter *TraceEventFilter) matchesEvent(event *TraceEvent) bool {
|
||||
// Check timestamp range
|
||||
if filter.MinTimestamp > 0 && event.Timestamp < filter.MinTimestamp {
|
||||
return false
|
||||
}
|
||||
if filter.MaxTimestamp > 0 && event.Timestamp > filter.MaxTimestamp {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check process names
|
||||
if len(filter.ProcessNames) > 0 {
|
||||
found := false
|
||||
for _, name := range filter.ProcessNames {
|
||||
if strings.Contains(event.ProcessName, name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check PIDs
|
||||
if len(filter.PIDs) > 0 {
|
||||
found := false
|
||||
for _, pid := range filter.PIDs {
|
||||
if event.PID == pid {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check UIDs
|
||||
if len(filter.UIDs) > 0 {
|
||||
found := false
|
||||
for _, uid := range filter.UIDs {
|
||||
if event.UID == uid {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check functions
|
||||
if len(filter.Functions) > 0 {
|
||||
found := false
|
||||
for _, function := range filter.Functions {
|
||||
if strings.Contains(event.Function, function) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check message filter
|
||||
if filter.MessageFilter != "" {
|
||||
if !strings.Contains(event.Message, filter.MessageFilter) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// TraceEventAggregator provides aggregation capabilities for trace events
|
||||
type TraceEventAggregator struct {
|
||||
events []TraceEvent
|
||||
}
|
||||
|
||||
// NewTraceEventAggregator creates a new event aggregator
|
||||
func NewTraceEventAggregator(events []TraceEvent) *TraceEventAggregator {
|
||||
return &TraceEventAggregator{
|
||||
events: events,
|
||||
}
|
||||
}
|
||||
|
||||
// CountByProcess returns event counts grouped by process
|
||||
func (agg *TraceEventAggregator) CountByProcess() map[string]int {
|
||||
counts := make(map[string]int)
|
||||
for _, event := range agg.events {
|
||||
counts[event.ProcessName]++
|
||||
}
|
||||
return counts
|
||||
}
|
||||
|
||||
// CountByFunction returns event counts grouped by function
|
||||
func (agg *TraceEventAggregator) CountByFunction() map[string]int {
|
||||
counts := make(map[string]int)
|
||||
for _, event := range agg.events {
|
||||
counts[event.Function]++
|
||||
}
|
||||
return counts
|
||||
}
|
||||
|
||||
// CountByPID returns event counts grouped by PID
|
||||
func (agg *TraceEventAggregator) CountByPID() map[int]int {
|
||||
counts := make(map[int]int)
|
||||
for _, event := range agg.events {
|
||||
counts[event.PID]++
|
||||
}
|
||||
return counts
|
||||
}
|
||||
|
||||
// GetTimeRange returns the time range of events
|
||||
func (agg *TraceEventAggregator) GetTimeRange() (int64, int64) {
|
||||
if len(agg.events) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
minTime := agg.events[0].Timestamp
|
||||
maxTime := agg.events[0].Timestamp
|
||||
|
||||
for _, event := range agg.events {
|
||||
if event.Timestamp < minTime {
|
||||
minTime = event.Timestamp
|
||||
}
|
||||
if event.Timestamp > maxTime {
|
||||
maxTime = event.Timestamp
|
||||
}
|
||||
}
|
||||
|
||||
return minTime, maxTime
|
||||
}
|
||||
|
||||
// GetEventRate calculates events per second
|
||||
func (agg *TraceEventAggregator) GetEventRate() float64 {
|
||||
if len(agg.events) < 2 {
|
||||
return 0
|
||||
}
|
||||
|
||||
minTime, maxTime := agg.GetTimeRange()
|
||||
durationNs := maxTime - minTime
|
||||
durationSeconds := float64(durationNs) / float64(time.Second)
|
||||
|
||||
if durationSeconds == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return float64(len(agg.events)) / durationSeconds
|
||||
}
|
||||
|
||||
// GetTopProcesses returns the most active processes
|
||||
func (agg *TraceEventAggregator) GetTopProcesses(limit int) []ProcessStat {
|
||||
counts := agg.CountByProcess()
|
||||
total := len(agg.events)
|
||||
|
||||
var stats []ProcessStat
|
||||
for processName, count := range counts {
|
||||
percentage := float64(count) / float64(total) * 100
|
||||
stats = append(stats, ProcessStat{
|
||||
ProcessName: processName,
|
||||
EventCount: count,
|
||||
Percentage: percentage,
|
||||
})
|
||||
}
|
||||
|
||||
// Simple sorting by event count (bubble sort for simplicity)
|
||||
for i := 0; i < len(stats); i++ {
|
||||
for j := i + 1; j < len(stats); j++ {
|
||||
if stats[j].EventCount > stats[i].EventCount {
|
||||
stats[i], stats[j] = stats[j], stats[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if limit > 0 && limit < len(stats) {
|
||||
stats = stats[:limit]
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
587
internal/ebpf/ebpf_trace_manager.go
Normal file
587
internal/ebpf/ebpf_trace_manager.go
Normal file
@@ -0,0 +1,587 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"nannyagentv2/internal/logging"
|
||||
)
|
||||
|
||||
// TraceSpec describes a single trace request, modelled on the arguments of
// BCC's trace.py.
type TraceSpec struct {
	// ProbeType selects the probe kind: "p" (kprobe), "r" (kretprobe),
	// "t" (tracepoint) or "u" (uprobe).
	ProbeType string `json:"probe_type"`

	// Target is the function, syscall or tracepoint to attach to.
	Target string `json:"target"`

	// Library is the binary or library containing Target for userspace
	// probes; empty for kernel probes.
	Library string `json:"library,omitempty"`

	// Format is the printf-style text emitted for each hit
	// (e.g. "read %d bytes").
	Format string `json:"format"`

	// Arguments lists the values supplied to Format
	// (e.g. ["arg1", "arg2", "retval"]).
	Arguments []string `json:"arguments"`

	// Filter is an optional condition an event must satisfy
	// (e.g. "arg3 > 20000").
	Filter string `json:"filter,omitempty"`

	// Duration is how long to trace, in seconds.
	Duration int `json:"duration"`

	// PID optionally restricts tracing to one process ID.
	PID int `json:"pid,omitempty"`

	// TID optionally restricts tracing to one thread ID.
	TID int `json:"tid,omitempty"`

	// UID optionally restricts tracing to one user ID.
	UID int `json:"uid,omitempty"`

	// ProcessName optionally restricts tracing to one process name.
	ProcessName string `json:"process_name,omitempty"`
}
|
||||
|
||||
// TraceEvent is one event captured from an eBPF trace.
type TraceEvent struct {
	// Timestamp is the event time reported by the tracer.
	Timestamp int64 `json:"timestamp"`
	// PID is the ID of the process that triggered the event.
	PID int `json:"pid"`
	// TID is the ID of the thread that triggered the event.
	TID int `json:"tid"`
	// UID is the user ID associated with the event.
	UID int `json:"uid"`
	// ProcessName is the command name of the triggering process.
	ProcessName string `json:"process_name"`
	// Function is the traced function or probe target.
	Function string `json:"function"`
	// Message is the formatted text emitted for the event.
	Message string `json:"message"`
	// RawArgs holds arguments extracted from Message, keyed by name.
	RawArgs map[string]string `json:"raw_args"`
	// CPU is the CPU the event was recorded on, when known.
	CPU int `json:"cpu,omitempty"`
}
|
||||
|
||||
// TraceResult represents the results of a tracing session
|
||||
type TraceResult struct {
|
||||
TraceID string `json:"trace_id"`
|
||||
Spec TraceSpec `json:"spec"`
|
||||
Events []TraceEvent `json:"events"`
|
||||
EventCount int `json:"event_count"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Summary string `json:"summary"`
|
||||
Statistics TraceStats `json:"statistics"`
|
||||
}
|
||||
|
||||
// TraceStats provides statistics about the trace
|
||||
type TraceStats struct {
|
||||
TotalEvents int `json:"total_events"`
|
||||
EventsByProcess map[string]int `json:"events_by_process"`
|
||||
EventsByUID map[int]int `json:"events_by_uid"`
|
||||
EventsPerSecond float64 `json:"events_per_second"`
|
||||
TopProcesses []ProcessStat `json:"top_processes"`
|
||||
}
|
||||
|
||||
// ProcessStat holds the event statistics of a single process.
type ProcessStat struct {
	// ProcessName is the command name of the process.
	ProcessName string `json:"process_name"`
	// PID is the process ID.
	PID int `json:"pid"`
	// EventCount is the number of events attributed to the process.
	EventCount int `json:"event_count"`
	// Percentage is the process's share of all events, from 0 to 100.
	Percentage float64 `json:"percentage"`
}
|
||||
|
||||
// BCCTraceManager implements advanced eBPF tracing similar to BCC trace.py
|
||||
type BCCTraceManager struct {
|
||||
traces map[string]*RunningTrace
|
||||
tracesLock sync.RWMutex
|
||||
traceCounter int
|
||||
capabilities map[string]bool
|
||||
}
|
||||
|
||||
// RunningTrace represents an active trace session
|
||||
type RunningTrace struct {
|
||||
ID string
|
||||
Spec TraceSpec
|
||||
Process *exec.Cmd
|
||||
Events []TraceEvent
|
||||
StartTime time.Time
|
||||
Cancel context.CancelFunc
|
||||
Context context.Context
|
||||
Done chan struct{} // Signal when trace monitoring is complete
|
||||
}
|
||||
|
||||
// NewBCCTraceManager creates a new BCC-style trace manager
|
||||
func NewBCCTraceManager() *BCCTraceManager {
|
||||
manager := &BCCTraceManager{
|
||||
traces: make(map[string]*RunningTrace),
|
||||
capabilities: make(map[string]bool),
|
||||
}
|
||||
|
||||
manager.testCapabilities()
|
||||
return manager
|
||||
}
|
||||
|
||||
// testCapabilities checks what tracing capabilities are available
|
||||
func (tm *BCCTraceManager) testCapabilities() {
|
||||
// Test if bpftrace is available
|
||||
if _, err := exec.LookPath("bpftrace"); err == nil {
|
||||
tm.capabilities["bpftrace"] = true
|
||||
} else {
|
||||
tm.capabilities["bpftrace"] = false
|
||||
}
|
||||
|
||||
// Test if perf is available for fallback
|
||||
if _, err := exec.LookPath("perf"); err == nil {
|
||||
tm.capabilities["perf"] = true
|
||||
} else {
|
||||
tm.capabilities["perf"] = false
|
||||
}
|
||||
|
||||
// Test root privileges (required for eBPF)
|
||||
tm.capabilities["root_access"] = os.Geteuid() == 0
|
||||
|
||||
// Test kernel version
|
||||
cmd := exec.Command("uname", "-r")
|
||||
output, err := cmd.Output()
|
||||
if err == nil {
|
||||
version := strings.TrimSpace(string(output))
|
||||
// eBPF requires kernel 4.4+
|
||||
tm.capabilities["kernel_ebpf"] = !strings.HasPrefix(version, "3.")
|
||||
} else {
|
||||
tm.capabilities["kernel_ebpf"] = false
|
||||
}
|
||||
|
||||
// Test if we can access debugfs
|
||||
if _, err := os.Stat("/sys/kernel/debug/tracing/available_events"); err == nil {
|
||||
tm.capabilities["debugfs_access"] = true
|
||||
} else {
|
||||
tm.capabilities["debugfs_access"] = false
|
||||
}
|
||||
|
||||
logging.Debug("BCC Trace capabilities: %+v", tm.capabilities)
|
||||
}
|
||||
|
||||
// GetCapabilities returns available tracing capabilities
|
||||
func (tm *BCCTraceManager) GetCapabilities() map[string]bool {
|
||||
tm.tracesLock.RLock()
|
||||
defer tm.tracesLock.RUnlock()
|
||||
|
||||
caps := make(map[string]bool)
|
||||
for k, v := range tm.capabilities {
|
||||
caps[k] = v
|
||||
}
|
||||
return caps
|
||||
}
|
||||
|
||||
// StartTrace starts a new trace session based on the specification
|
||||
func (tm *BCCTraceManager) StartTrace(spec TraceSpec) (string, error) {
|
||||
if !tm.capabilities["bpftrace"] {
|
||||
return "", fmt.Errorf("bpftrace not available - install bpftrace package")
|
||||
}
|
||||
|
||||
if !tm.capabilities["root_access"] {
|
||||
return "", fmt.Errorf("root access required for eBPF tracing")
|
||||
}
|
||||
|
||||
if !tm.capabilities["kernel_ebpf"] {
|
||||
return "", fmt.Errorf("kernel version does not support eBPF")
|
||||
}
|
||||
|
||||
tm.tracesLock.Lock()
|
||||
defer tm.tracesLock.Unlock()
|
||||
|
||||
// Generate trace ID
|
||||
tm.traceCounter++
|
||||
traceID := fmt.Sprintf("trace_%d", tm.traceCounter)
|
||||
|
||||
// Generate bpftrace script
|
||||
script, err := tm.generateBpftraceScript(spec)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate bpftrace script: %w", err)
|
||||
}
|
||||
|
||||
// Debug: log the generated script
|
||||
logging.Debug("Generated bpftrace script for %s:\n%s", spec.Target, script)
|
||||
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(spec.Duration)*time.Second)
|
||||
|
||||
// Start bpftrace process
|
||||
cmd := exec.CommandContext(ctx, "bpftrace", "-e", script)
|
||||
|
||||
// Create stdout pipe BEFORE starting
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("failed to create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
trace := &RunningTrace{
|
||||
ID: traceID,
|
||||
Spec: spec,
|
||||
Process: cmd,
|
||||
Events: []TraceEvent{},
|
||||
StartTime: time.Now(),
|
||||
Cancel: cancel,
|
||||
Context: ctx,
|
||||
Done: make(chan struct{}), // Initialize completion signal
|
||||
}
|
||||
|
||||
// Start the trace
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("failed to start bpftrace: %w", err)
|
||||
}
|
||||
|
||||
tm.traces[traceID] = trace
|
||||
|
||||
// Monitor the trace in a goroutine
|
||||
go tm.monitorTrace(traceID, stdout)
|
||||
|
||||
logging.Debug("Started BCC-style trace %s for target %s", traceID, spec.Target)
|
||||
return traceID, nil
|
||||
} // generateBpftraceScript generates a bpftrace script based on the trace specification
|
||||
func (tm *BCCTraceManager) generateBpftraceScript(spec TraceSpec) (string, error) {
|
||||
var script strings.Builder
|
||||
|
||||
// Build probe specification
|
||||
var probe string
|
||||
switch spec.ProbeType {
|
||||
case "p", "": // kprobe (default)
|
||||
if strings.HasPrefix(spec.Target, "sys_") || strings.HasPrefix(spec.Target, "__x64_sys_") {
|
||||
probe = fmt.Sprintf("kprobe:%s", spec.Target)
|
||||
} else {
|
||||
probe = fmt.Sprintf("kprobe:%s", spec.Target)
|
||||
}
|
||||
case "r": // kretprobe
|
||||
if strings.HasPrefix(spec.Target, "sys_") || strings.HasPrefix(spec.Target, "__x64_sys_") {
|
||||
probe = fmt.Sprintf("kretprobe:%s", spec.Target)
|
||||
} else {
|
||||
probe = fmt.Sprintf("kretprobe:%s", spec.Target)
|
||||
}
|
||||
case "t": // tracepoint
|
||||
// If target already includes tracepoint prefix, use as-is
|
||||
if strings.HasPrefix(spec.Target, "tracepoint:") {
|
||||
probe = spec.Target
|
||||
} else {
|
||||
probe = fmt.Sprintf("tracepoint:%s", spec.Target)
|
||||
}
|
||||
case "u": // uprobe
|
||||
if spec.Library == "" {
|
||||
return "", fmt.Errorf("library required for uprobe")
|
||||
}
|
||||
probe = fmt.Sprintf("uprobe:%s:%s", spec.Library, spec.Target)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported probe type: %s", spec.ProbeType)
|
||||
}
|
||||
|
||||
// Add BEGIN block
|
||||
script.WriteString("BEGIN {\n")
|
||||
script.WriteString(fmt.Sprintf(" printf(\"Starting trace for %s...\\n\");\n", spec.Target))
|
||||
script.WriteString("}\n\n")
|
||||
|
||||
// Build the main probe
|
||||
script.WriteString(fmt.Sprintf("%s {\n", probe))
|
||||
|
||||
// Add filters if specified
|
||||
if tm.needsFiltering(spec) {
|
||||
script.WriteString(" if (")
|
||||
filters := tm.buildFilters(spec)
|
||||
script.WriteString(strings.Join(filters, " && "))
|
||||
script.WriteString(") {\n")
|
||||
}
|
||||
|
||||
// Build output format
|
||||
outputFormat := tm.buildOutputFormat(spec)
|
||||
script.WriteString(fmt.Sprintf(" printf(\"%s\\n\"", outputFormat))
|
||||
|
||||
// Add arguments
|
||||
args := tm.buildArgumentList(spec)
|
||||
if len(args) > 0 {
|
||||
script.WriteString(", ")
|
||||
script.WriteString(strings.Join(args, ", "))
|
||||
}
|
||||
|
||||
script.WriteString(");\n")
|
||||
|
||||
// Close filter if block
|
||||
if tm.needsFiltering(spec) {
|
||||
script.WriteString(" }\n")
|
||||
}
|
||||
|
||||
script.WriteString("}\n\n")
|
||||
|
||||
// Add END block
|
||||
script.WriteString("END {\n")
|
||||
script.WriteString(fmt.Sprintf(" printf(\"Trace completed for %s\\n\");\n", spec.Target))
|
||||
script.WriteString("}\n")
|
||||
|
||||
return script.String(), nil
|
||||
}
|
||||
|
||||
// needsFiltering checks if any filters are needed
|
||||
func (tm *BCCTraceManager) needsFiltering(spec TraceSpec) bool {
|
||||
return spec.PID != 0 || spec.TID != 0 || spec.UID != -1 ||
|
||||
spec.ProcessName != "" || spec.Filter != ""
|
||||
}
|
||||
|
||||
// buildFilters builds the filter conditions
|
||||
func (tm *BCCTraceManager) buildFilters(spec TraceSpec) []string {
|
||||
var filters []string
|
||||
|
||||
if spec.PID != 0 {
|
||||
filters = append(filters, fmt.Sprintf("pid == %d", spec.PID))
|
||||
}
|
||||
|
||||
if spec.TID != 0 {
|
||||
filters = append(filters, fmt.Sprintf("tid == %d", spec.TID))
|
||||
}
|
||||
|
||||
if spec.UID != -1 {
|
||||
filters = append(filters, fmt.Sprintf("uid == %d", spec.UID))
|
||||
}
|
||||
|
||||
if spec.ProcessName != "" {
|
||||
filters = append(filters, fmt.Sprintf("strncmp(comm, \"%s\", %d) == 0", spec.ProcessName, len(spec.ProcessName)))
|
||||
}
|
||||
|
||||
// Add custom filter
|
||||
if spec.Filter != "" {
|
||||
// Convert common patterns to bpftrace syntax
|
||||
customFilter := strings.ReplaceAll(spec.Filter, "arg", "arg")
|
||||
filters = append(filters, customFilter)
|
||||
}
|
||||
|
||||
return filters
|
||||
}
|
||||
|
||||
// buildOutputFormat creates the output format string
|
||||
func (tm *BCCTraceManager) buildOutputFormat(spec TraceSpec) string {
|
||||
if spec.Format != "" {
|
||||
// Use custom format
|
||||
return fmt.Sprintf("TRACE|%%d|%%d|%%d|%%s|%s|%s", spec.Target, spec.Format)
|
||||
}
|
||||
|
||||
// Default format
|
||||
return fmt.Sprintf("TRACE|%%d|%%d|%%d|%%s|%s|called", spec.Target)
|
||||
}
|
||||
|
||||
// buildArgumentList creates the argument list for printf
|
||||
func (tm *BCCTraceManager) buildArgumentList(spec TraceSpec) []string {
|
||||
// Always include timestamp, pid, tid, comm
|
||||
args := []string{"nsecs", "pid", "tid", "comm"}
|
||||
|
||||
// Add custom arguments
|
||||
for _, arg := range spec.Arguments {
|
||||
switch arg {
|
||||
case "arg1", "arg2", "arg3", "arg4", "arg5", "arg6":
|
||||
args = append(args, fmt.Sprintf("arg%s", strings.TrimPrefix(arg, "arg")))
|
||||
case "retval":
|
||||
args = append(args, "retval")
|
||||
case "cpu":
|
||||
args = append(args, "cpu")
|
||||
default:
|
||||
// Custom expression
|
||||
args = append(args, arg)
|
||||
}
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// monitorTrace monitors a running trace and collects events
|
||||
func (tm *BCCTraceManager) monitorTrace(traceID string, stdout io.ReadCloser) {
|
||||
tm.tracesLock.Lock()
|
||||
trace, exists := tm.traces[traceID]
|
||||
if !exists {
|
||||
tm.tracesLock.Unlock()
|
||||
return
|
||||
}
|
||||
tm.tracesLock.Unlock()
|
||||
|
||||
// Start reading output in a goroutine
|
||||
go func() {
|
||||
scanner := NewEventScanner(stdout)
|
||||
for scanner.Scan() {
|
||||
event := scanner.Event()
|
||||
if event != nil {
|
||||
tm.tracesLock.Lock()
|
||||
if t, exists := tm.traces[traceID]; exists {
|
||||
t.Events = append(t.Events, *event)
|
||||
}
|
||||
tm.tracesLock.Unlock()
|
||||
}
|
||||
}
|
||||
stdout.Close()
|
||||
}()
|
||||
|
||||
// Wait for the process to complete
|
||||
err := trace.Process.Wait()
|
||||
|
||||
// Clean up
|
||||
trace.Cancel()
|
||||
|
||||
tm.tracesLock.Lock()
|
||||
if err != nil && err.Error() != "signal: killed" {
|
||||
logging.Warning("Trace %s completed with error: %v", traceID, err)
|
||||
} else {
|
||||
logging.Debug("Trace %s completed successfully with %d events",
|
||||
traceID, len(trace.Events))
|
||||
}
|
||||
|
||||
// Signal that monitoring is complete
|
||||
close(trace.Done)
|
||||
tm.tracesLock.Unlock()
|
||||
}
|
||||
|
||||
// GetTraceResult returns the results of a completed trace
|
||||
func (tm *BCCTraceManager) GetTraceResult(traceID string) (*TraceResult, error) {
|
||||
tm.tracesLock.RLock()
|
||||
trace, exists := tm.traces[traceID]
|
||||
if !exists {
|
||||
tm.tracesLock.RUnlock()
|
||||
return nil, fmt.Errorf("trace %s not found", traceID)
|
||||
}
|
||||
tm.tracesLock.RUnlock()
|
||||
|
||||
// Wait for trace monitoring to complete
|
||||
select {
|
||||
case <-trace.Done:
|
||||
// Trace monitoring completed
|
||||
case <-time.After(5 * time.Second):
|
||||
// Timeout waiting for completion
|
||||
return nil, fmt.Errorf("timeout waiting for trace %s to complete", traceID)
|
||||
}
|
||||
|
||||
// Now safely read the final results
|
||||
tm.tracesLock.RLock()
|
||||
defer tm.tracesLock.RUnlock()
|
||||
|
||||
result := &TraceResult{
|
||||
TraceID: traceID,
|
||||
Spec: trace.Spec,
|
||||
Events: make([]TraceEvent, len(trace.Events)),
|
||||
EventCount: len(trace.Events),
|
||||
StartTime: trace.StartTime,
|
||||
EndTime: time.Now(),
|
||||
}
|
||||
|
||||
copy(result.Events, trace.Events)
|
||||
|
||||
// Calculate statistics
|
||||
result.Statistics = tm.calculateStatistics(result.Events, result.EndTime.Sub(result.StartTime))
|
||||
|
||||
// Generate summary
|
||||
result.Summary = tm.generateSummary(result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// calculateStatistics calculates statistics for the trace results
|
||||
func (tm *BCCTraceManager) calculateStatistics(events []TraceEvent, duration time.Duration) TraceStats {
|
||||
stats := TraceStats{
|
||||
TotalEvents: len(events),
|
||||
EventsByProcess: make(map[string]int),
|
||||
EventsByUID: make(map[int]int),
|
||||
}
|
||||
|
||||
if duration > 0 {
|
||||
stats.EventsPerSecond = float64(len(events)) / duration.Seconds()
|
||||
}
|
||||
|
||||
// Calculate per-process and per-UID statistics
|
||||
for _, event := range events {
|
||||
stats.EventsByProcess[event.ProcessName]++
|
||||
stats.EventsByUID[event.UID]++
|
||||
}
|
||||
|
||||
// Calculate top processes
|
||||
for processName, count := range stats.EventsByProcess {
|
||||
percentage := float64(count) / float64(len(events)) * 100
|
||||
stats.TopProcesses = append(stats.TopProcesses, ProcessStat{
|
||||
ProcessName: processName,
|
||||
EventCount: count,
|
||||
Percentage: percentage,
|
||||
})
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// generateSummary generates a human-readable summary
|
||||
func (tm *BCCTraceManager) generateSummary(result *TraceResult) string {
|
||||
duration := result.EndTime.Sub(result.StartTime)
|
||||
|
||||
summary := fmt.Sprintf("Traced %s for %v, captured %d events (%.2f events/sec)",
|
||||
result.Spec.Target, duration, result.EventCount, result.Statistics.EventsPerSecond)
|
||||
|
||||
if len(result.Statistics.TopProcesses) > 0 {
|
||||
summary += fmt.Sprintf(", top process: %s (%d events)",
|
||||
result.Statistics.TopProcesses[0].ProcessName,
|
||||
result.Statistics.TopProcesses[0].EventCount)
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
|
||||
// StopTrace stops an active trace
|
||||
func (tm *BCCTraceManager) StopTrace(traceID string) error {
|
||||
tm.tracesLock.Lock()
|
||||
defer tm.tracesLock.Unlock()
|
||||
|
||||
trace, exists := tm.traces[traceID]
|
||||
if !exists {
|
||||
return fmt.Errorf("trace %s not found", traceID)
|
||||
}
|
||||
|
||||
if trace.Process.ProcessState == nil {
|
||||
// Process is still running, kill it
|
||||
if err := trace.Process.Process.Kill(); err != nil {
|
||||
return fmt.Errorf("failed to stop trace: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
trace.Cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListActiveTraces returns a list of active trace IDs
|
||||
func (tm *BCCTraceManager) ListActiveTraces() []string {
|
||||
tm.tracesLock.RLock()
|
||||
defer tm.tracesLock.RUnlock()
|
||||
|
||||
var active []string
|
||||
for id, trace := range tm.traces {
|
||||
if trace.Process.ProcessState == nil {
|
||||
active = append(active, id)
|
||||
}
|
||||
}
|
||||
|
||||
return active
|
||||
}
|
||||
|
||||
// GetSummary returns a summary of the trace manager state
|
||||
func (tm *BCCTraceManager) GetSummary() map[string]interface{} {
|
||||
tm.tracesLock.RLock()
|
||||
defer tm.tracesLock.RUnlock()
|
||||
|
||||
activeCount := 0
|
||||
completedCount := 0
|
||||
|
||||
for _, trace := range tm.traces {
|
||||
if trace.Process.ProcessState == nil {
|
||||
activeCount++
|
||||
} else {
|
||||
completedCount++
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"capabilities": tm.capabilities,
|
||||
"active_traces": activeCount,
|
||||
"completed_traces": completedCount,
|
||||
"total_traces": len(tm.traces),
|
||||
"active_trace_ids": tm.ListActiveTraces(),
|
||||
}
|
||||
}
|
||||
396
internal/ebpf/ebpf_trace_specs.go
Normal file
396
internal/ebpf/ebpf_trace_specs.go
Normal file
@@ -0,0 +1,396 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TestTraceSpecs provides test trace specifications for unit testing the BCC-style tracing
|
||||
// These are used to validate the tracing functionality without requiring remote API calls
|
||||
var TestTraceSpecs = map[string]TraceSpec{
|
||||
// Basic system call tracing for testing
|
||||
"test_sys_open": {
|
||||
ProbeType: "p",
|
||||
Target: "__x64_sys_openat",
|
||||
Format: "opening file: %s",
|
||||
Arguments: []string{"arg2@user"}, // filename
|
||||
Duration: 5, // Short duration for testing
|
||||
},
|
||||
|
||||
"test_sys_read": {
|
||||
ProbeType: "p",
|
||||
Target: "__x64_sys_read",
|
||||
Format: "read %d bytes from fd %d",
|
||||
Arguments: []string{"arg3", "arg1"}, // count, fd
|
||||
Filter: "arg3 > 100", // Only reads >100 bytes for testing
|
||||
Duration: 5,
|
||||
},
|
||||
|
||||
"test_sys_write": {
|
||||
ProbeType: "p",
|
||||
Target: "__x64_sys_write",
|
||||
Format: "write %d bytes to fd %d",
|
||||
Arguments: []string{"arg3", "arg1"}, // count, fd
|
||||
Duration: 5,
|
||||
},
|
||||
|
||||
"test_process_creation": {
|
||||
ProbeType: "p",
|
||||
Target: "__x64_sys_execve",
|
||||
Format: "exec: %s",
|
||||
Arguments: []string{"arg1@user"}, // filename
|
||||
Duration: 5,
|
||||
},
|
||||
|
||||
// Test with different probe types
|
||||
"test_kretprobe": {
|
||||
ProbeType: "r",
|
||||
Target: "__x64_sys_openat",
|
||||
Format: "open returned: %d",
|
||||
Arguments: []string{"retval"},
|
||||
Duration: 5,
|
||||
},
|
||||
|
||||
"test_with_filter": {
|
||||
ProbeType: "p",
|
||||
Target: "__x64_sys_write",
|
||||
Format: "stdout write: %d bytes",
|
||||
Arguments: []string{"arg3"},
|
||||
Filter: "arg1 == 1", // Only stdout writes
|
||||
Duration: 5,
|
||||
},
|
||||
}
|
||||
|
||||
// GetTestSpec returns a pre-defined test trace specification
|
||||
func GetTestSpec(name string) (TraceSpec, bool) {
|
||||
spec, exists := TestTraceSpecs[name]
|
||||
return spec, exists
|
||||
}
|
||||
|
||||
// ListTestSpecs returns a short description for each predefined test
// specification, keyed by spec name. A fresh map is built on every call.
func ListTestSpecs() map[string]string {
	return map[string]string{
		"test_sys_open":         "Test file open operations",
		"test_sys_read":         "Test read operations (>100 bytes)",
		"test_sys_write":        "Test write operations",
		"test_process_creation": "Test process execution",
		"test_kretprobe":        "Test kretprobe on file open",
		"test_with_filter":      "Test filtered writes to stdout",
	}
}
|
||||
|
||||
// TraceSpecBuilder helps build custom trace specifications
|
||||
type TraceSpecBuilder struct {
|
||||
spec TraceSpec
|
||||
}
|
||||
|
||||
// NewTraceSpecBuilder creates a new trace specification builder
|
||||
func NewTraceSpecBuilder() *TraceSpecBuilder {
|
||||
return &TraceSpecBuilder{
|
||||
spec: TraceSpec{
|
||||
ProbeType: "p", // Default to kprobe
|
||||
Duration: 30, // Default 30 seconds
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Kprobe sets up a kernel probe
|
||||
func (b *TraceSpecBuilder) Kprobe(function string) *TraceSpecBuilder {
|
||||
b.spec.ProbeType = "p"
|
||||
b.spec.Target = function
|
||||
return b
|
||||
}
|
||||
|
||||
// Kretprobe sets up a kernel return probe
|
||||
func (b *TraceSpecBuilder) Kretprobe(function string) *TraceSpecBuilder {
|
||||
b.spec.ProbeType = "r"
|
||||
b.spec.Target = function
|
||||
return b
|
||||
}
|
||||
|
||||
// Tracepoint sets up a tracepoint
|
||||
func (b *TraceSpecBuilder) Tracepoint(category, name string) *TraceSpecBuilder {
|
||||
b.spec.ProbeType = "t"
|
||||
b.spec.Target = fmt.Sprintf("%s:%s", category, name)
|
||||
return b
|
||||
}
|
||||
|
||||
// Uprobe sets up a userspace probe
|
||||
func (b *TraceSpecBuilder) Uprobe(library, function string) *TraceSpecBuilder {
|
||||
b.spec.ProbeType = "u"
|
||||
b.spec.Library = library
|
||||
b.spec.Target = function
|
||||
return b
|
||||
}
|
||||
|
||||
// Format sets the output format string
|
||||
func (b *TraceSpecBuilder) Format(format string, args ...string) *TraceSpecBuilder {
|
||||
b.spec.Format = format
|
||||
b.spec.Arguments = args
|
||||
return b
|
||||
}
|
||||
|
||||
// Filter adds a filter condition
|
||||
func (b *TraceSpecBuilder) Filter(condition string) *TraceSpecBuilder {
|
||||
b.spec.Filter = condition
|
||||
return b
|
||||
}
|
||||
|
||||
// Duration sets the trace duration in seconds
|
||||
func (b *TraceSpecBuilder) Duration(seconds int) *TraceSpecBuilder {
|
||||
b.spec.Duration = seconds
|
||||
return b
|
||||
}
|
||||
|
||||
// PID filters by process ID
|
||||
func (b *TraceSpecBuilder) PID(pid int) *TraceSpecBuilder {
|
||||
b.spec.PID = pid
|
||||
return b
|
||||
}
|
||||
|
||||
// UID filters by user ID
|
||||
func (b *TraceSpecBuilder) UID(uid int) *TraceSpecBuilder {
|
||||
b.spec.UID = uid
|
||||
return b
|
||||
}
|
||||
|
||||
// ProcessName filters by process name
|
||||
func (b *TraceSpecBuilder) ProcessName(name string) *TraceSpecBuilder {
|
||||
b.spec.ProcessName = name
|
||||
return b
|
||||
}
|
||||
|
||||
// Build returns the constructed trace specification
|
||||
func (b *TraceSpecBuilder) Build() TraceSpec {
|
||||
return b.spec
|
||||
}
|
||||
|
||||
// TraceSpecParser converts textual trace descriptions (BCC trace.py style
// strings, JSON documents) into TraceSpec values. It carries no state.
type TraceSpecParser struct{}
|
||||
|
||||
// NewTraceSpecParser creates a new parser
|
||||
func NewTraceSpecParser() *TraceSpecParser {
|
||||
return &TraceSpecParser{}
|
||||
}
|
||||
|
||||
// ParseFromBCCStyle parses BCC trace.py style specifications
|
||||
// Examples:
|
||||
//
|
||||
// "sys_open" -> trace sys_open syscall
|
||||
// "p::do_sys_open" -> kprobe on do_sys_open
|
||||
// "r::do_sys_open" -> kretprobe on do_sys_open
|
||||
// "t:syscalls:sys_enter_open" -> tracepoint
|
||||
// "sys_read (arg3 > 1024)" -> with filter
|
||||
// "sys_read \"read %d bytes\", arg3" -> with format
|
||||
func (p *TraceSpecParser) ParseFromBCCStyle(spec string) (TraceSpec, error) {
|
||||
result := TraceSpec{
|
||||
ProbeType: "p",
|
||||
Duration: 30,
|
||||
}
|
||||
|
||||
// Split by quotes to separate format string
|
||||
parts := strings.Split(spec, "\"")
|
||||
|
||||
var probeSpec string
|
||||
if len(parts) >= 1 {
|
||||
probeSpec = strings.TrimSpace(parts[0])
|
||||
}
|
||||
|
||||
var formatPart string
|
||||
if len(parts) >= 2 {
|
||||
formatPart = parts[1]
|
||||
}
|
||||
|
||||
var argsPart string
|
||||
if len(parts) >= 3 {
|
||||
argsPart = strings.TrimSpace(parts[2])
|
||||
if strings.HasPrefix(argsPart, ",") {
|
||||
argsPart = strings.TrimSpace(argsPart[1:])
|
||||
}
|
||||
}
|
||||
|
||||
// Parse probe specification
|
||||
if err := p.parseProbeSpec(probeSpec, &result); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Parse format string
|
||||
if formatPart != "" {
|
||||
result.Format = formatPart
|
||||
}
|
||||
|
||||
// Parse arguments
|
||||
if argsPart != "" {
|
||||
result.Arguments = p.parseArguments(argsPart)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseProbeSpec parses the probe specification part
|
||||
func (p *TraceSpecParser) parseProbeSpec(spec string, result *TraceSpec) error {
|
||||
// Handle filter conditions in parentheses
|
||||
if idx := strings.Index(spec, "("); idx != -1 {
|
||||
filterEnd := strings.LastIndex(spec, ")")
|
||||
if filterEnd > idx {
|
||||
result.Filter = strings.TrimSpace(spec[idx+1 : filterEnd])
|
||||
spec = strings.TrimSpace(spec[:idx])
|
||||
}
|
||||
}
|
||||
|
||||
// Parse probe type and target
|
||||
if strings.Contains(spec, ":") {
|
||||
parts := strings.SplitN(spec, ":", 3)
|
||||
|
||||
if len(parts) >= 1 && parts[0] != "" {
|
||||
switch parts[0] {
|
||||
case "p":
|
||||
result.ProbeType = "p"
|
||||
case "r":
|
||||
result.ProbeType = "r"
|
||||
case "t":
|
||||
result.ProbeType = "t"
|
||||
case "u":
|
||||
result.ProbeType = "u"
|
||||
default:
|
||||
return fmt.Errorf("unsupported probe type: %s", parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) >= 2 {
|
||||
result.Library = parts[1]
|
||||
}
|
||||
|
||||
if len(parts) >= 3 {
|
||||
result.Target = parts[2]
|
||||
} else if len(parts) == 2 {
|
||||
result.Target = parts[1]
|
||||
result.Library = ""
|
||||
}
|
||||
} else {
|
||||
// Simple function name
|
||||
result.Target = spec
|
||||
|
||||
// Auto-detect syscall format
|
||||
if strings.HasPrefix(spec, "sys_") && !strings.HasPrefix(spec, "__x64_sys_") {
|
||||
result.Target = "__x64_sys_" + spec[4:]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseArguments parses the arguments part
|
||||
func (p *TraceSpecParser) parseArguments(args string) []string {
|
||||
var result []string
|
||||
|
||||
// Split by comma and clean up
|
||||
parts := strings.Split(args, ",")
|
||||
for _, part := range parts {
|
||||
arg := strings.TrimSpace(part)
|
||||
if arg != "" {
|
||||
result = append(result, arg)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ParseFromJSON parses trace specification from JSON
|
||||
func (p *TraceSpecParser) ParseFromJSON(jsonData []byte) (TraceSpec, error) {
|
||||
var spec TraceSpec
|
||||
err := json.Unmarshal(jsonData, &spec)
|
||||
return spec, err
|
||||
}
|
||||
|
||||
// GetCommonSpec returns a pre-defined test trace specification (renamed for backward compatibility)
|
||||
func GetCommonSpec(name string) (TraceSpec, bool) {
|
||||
// Map old names to new test names for compatibility
|
||||
testName := name
|
||||
if strings.HasPrefix(name, "trace_") {
|
||||
testName = strings.Replace(name, "trace_", "test_", 1)
|
||||
}
|
||||
|
||||
spec, exists := TestTraceSpecs[testName]
|
||||
return spec, exists
|
||||
}
|
||||
|
||||
// ListCommonSpecs returns all available test trace specifications (renamed for backward compatibility)
|
||||
func ListCommonSpecs() map[string]string {
|
||||
return ListTestSpecs()
|
||||
}
|
||||
|
||||
// ValidateTraceSpec validates a trace specification
|
||||
func ValidateTraceSpec(spec TraceSpec) error {
|
||||
if spec.Target == "" {
|
||||
return fmt.Errorf("target function/syscall is required")
|
||||
}
|
||||
|
||||
if spec.Duration <= 0 {
|
||||
return fmt.Errorf("duration must be positive")
|
||||
}
|
||||
|
||||
if spec.Duration > 600 { // 10 minutes max
|
||||
return fmt.Errorf("duration too long (max 600 seconds)")
|
||||
}
|
||||
|
||||
switch spec.ProbeType {
|
||||
case "p", "r", "t", "u":
|
||||
// Valid probe types
|
||||
case "":
|
||||
// Default to kprobe
|
||||
default:
|
||||
return fmt.Errorf("unsupported probe type: %s", spec.ProbeType)
|
||||
}
|
||||
|
||||
if spec.ProbeType == "u" && spec.Library == "" {
|
||||
return fmt.Errorf("library required for userspace probes")
|
||||
}
|
||||
|
||||
if spec.ProbeType == "t" && !strings.Contains(spec.Target, ":") {
|
||||
return fmt.Errorf("tracepoint requires format 'category:name'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SuggestSyscallTargets proposes trace names for a free-text issue
// description.
//
// The description is lower-cased and scanned for keywords belonging to five
// categories (file I/O, network, process, memory, performance). Every matching
// category contributes its trace names, in that order; a name contributed by
// more than one category appears only once, at its first position. When no
// category matches, a general-purpose set is returned. The result is never
// empty.
func SuggestSyscallTargets(issueDescription string) []string {
	description := strings.ToLower(issueDescription)

	// mentions reports whether any keyword occurs as a substring.
	mentions := func(keywords ...string) bool {
		for _, keyword := range keywords {
			if strings.Contains(description, keyword) {
				return true
			}
		}
		return false
	}

	// "io" is too short for a substring test: it occurs inside "connection",
	// "permission", "action" and so on. Accept it only written as "i/o" or at
	// the start of a word ("io", "iops", "iowait").
	mentionsIO := strings.Contains(description, "i/o")
	if !mentionsIO {
		words := strings.FieldsFunc(description, func(r rune) bool {
			return (r < 'a' || r > 'z') && (r < '0' || r > '9')
		})
		for _, word := range words {
			if strings.HasPrefix(word, "io") {
				mentionsIO = true
				break
			}
		}
	}

	var suggestions []string
	seen := make(map[string]bool)
	// add appends names not suggested yet, preserving first-seen order.
	add := func(names ...string) {
		for _, name := range names {
			if seen[name] {
				continue
			}
			seen[name] = true
			suggestions = append(suggestions, name)
		}
	}

	// File I/O issues
	if mentions("file", "disk") || mentionsIO {
		add("trace_sys_open", "trace_sys_read", "trace_sys_write", "trace_sys_unlink")
	}

	// Network issues
	if mentions("network", "socket", "connection") {
		add("trace_sys_connect", "trace_sys_socket", "trace_sys_bind", "trace_sys_accept")
	}

	// Process issues
	if mentions("process", "crash", "exec") {
		add("trace_sys_execve", "trace_sys_clone", "trace_sys_exit", "trace_sys_kill")
	}

	// Memory issues
	if mentions("memory", "malloc", "leak") {
		add("trace_sys_mmap", "trace_sys_brk")
	}

	// Performance issues - trace common syscalls
	if mentions("slow", "performance", "hang") {
		add("trace_sys_read", "trace_sys_write", "trace_sys_connect", "trace_sys_mmap")
	}

	// If no specific suggestions, provide general monitoring
	if len(suggestions) == 0 {
		add("trace_sys_execve", "trace_sys_open", "trace_sys_connect")
	}

	return suggestions
}
|
||||
921
internal/ebpf/ebpf_trace_test.go
Normal file
921
internal/ebpf/ebpf_trace_test.go
Normal file
@@ -0,0 +1,921 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestBCCTracing demonstrates and tests the new BCC-style tracing functionality
|
||||
// This test documents the expected behavior and response format of the agent
|
||||
func TestBCCTracing(t *testing.T) {
|
||||
fmt.Println("=== BCC-Style eBPF Tracing Unit Tests ===")
|
||||
fmt.Println()
|
||||
|
||||
// Test 1: List available test specifications
|
||||
t.Run("ListTestSpecs", func(t *testing.T) {
|
||||
specs := ListTestSpecs()
|
||||
fmt.Printf("📋 Available Test Specifications:\n")
|
||||
for name, description := range specs {
|
||||
fmt.Printf(" - %s: %s\n", name, description)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
if len(specs) == 0 {
|
||||
t.Error("No test specifications available")
|
||||
}
|
||||
})
|
||||
|
||||
// Test 2: Parse BCC-style specifications
|
||||
t.Run("ParseBCCStyle", func(t *testing.T) {
|
||||
parser := NewTraceSpecParser()
|
||||
|
||||
testCases := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
input: "sys_open",
|
||||
expected: "__x64_sys_open",
|
||||
},
|
||||
{
|
||||
input: "p::do_sys_open",
|
||||
expected: "do_sys_open",
|
||||
},
|
||||
{
|
||||
input: "r::sys_read",
|
||||
expected: "sys_read",
|
||||
},
|
||||
{
|
||||
input: "sys_write (arg1 == 1)",
|
||||
expected: "__x64_sys_write",
|
||||
},
|
||||
}
|
||||
|
||||
fmt.Printf("🔍 Testing BCC-style parsing:\n")
|
||||
for _, tc := range testCases {
|
||||
spec, err := parser.ParseFromBCCStyle(tc.input)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to parse '%s': %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf(" Input: '%s' -> Target: '%s', Type: '%s'\n",
|
||||
tc.input, spec.Target, spec.ProbeType)
|
||||
|
||||
if spec.Target != tc.expected {
|
||||
t.Errorf("Expected target '%s', got '%s'", tc.expected, spec.Target)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
})
|
||||
|
||||
// Test 3: Validate trace specifications
|
||||
t.Run("ValidateSpecs", func(t *testing.T) {
|
||||
fmt.Printf("✅ Testing trace specification validation:\n")
|
||||
|
||||
// Valid spec
|
||||
validSpec := TraceSpec{
|
||||
ProbeType: "p",
|
||||
Target: "__x64_sys_openat",
|
||||
Format: "opening file",
|
||||
Duration: 5,
|
||||
}
|
||||
|
||||
if err := ValidateTraceSpec(validSpec); err != nil {
|
||||
t.Errorf("Valid spec failed validation: %v", err)
|
||||
} else {
|
||||
fmt.Printf(" ✓ Valid specification passed\n")
|
||||
}
|
||||
|
||||
// Invalid spec - no target
|
||||
invalidSpec := TraceSpec{
|
||||
ProbeType: "p",
|
||||
Duration: 5,
|
||||
}
|
||||
|
||||
if err := ValidateTraceSpec(invalidSpec); err == nil {
|
||||
t.Error("Invalid spec (no target) should have failed validation")
|
||||
} else {
|
||||
fmt.Printf(" ✓ Invalid specification correctly rejected: %s\n", err.Error())
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
})
|
||||
|
||||
// Test 4: Simulate agent response format
|
||||
t.Run("SimulateAgentResponse", func(t *testing.T) {
|
||||
fmt.Printf("🤖 Simulating agent response for BCC-style tracing:\n")
|
||||
|
||||
// Get a test specification
|
||||
testSpec, exists := GetTestSpec("test_sys_open")
|
||||
if !exists {
|
||||
t.Fatal("test_sys_open specification not found")
|
||||
}
|
||||
|
||||
// Simulate what the agent would return
|
||||
mockResponse := simulateTraceExecution(testSpec)
|
||||
|
||||
// Print the response format
|
||||
responseJSON, _ := json.MarshalIndent(mockResponse, "", " ")
|
||||
fmt.Printf(" Expected Response Format:\n%s\n", string(responseJSON))
|
||||
|
||||
// Validate response structure
|
||||
if mockResponse["success"] != true {
|
||||
t.Error("Expected successful trace execution")
|
||||
}
|
||||
|
||||
if mockResponse["type"] != "bcc_trace" {
|
||||
t.Error("Expected type to be 'bcc_trace'")
|
||||
}
|
||||
|
||||
events, hasEvents := mockResponse["events"].([]TraceEvent)
|
||||
if !hasEvents || len(events) == 0 {
|
||||
t.Error("Expected trace events in response")
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
})
|
||||
|
||||
// Test 5: Test different probe types
|
||||
t.Run("TestProbeTypes", func(t *testing.T) {
|
||||
fmt.Printf("🔬 Testing different probe types:\n")
|
||||
|
||||
probeTests := []struct {
|
||||
specName string
|
||||
expected string
|
||||
}{
|
||||
{"test_sys_open", "kprobe"},
|
||||
{"test_kretprobe", "kretprobe"},
|
||||
{"test_with_filter", "kprobe with filter"},
|
||||
}
|
||||
|
||||
for _, test := range probeTests {
|
||||
spec, exists := GetTestSpec(test.specName)
|
||||
if !exists {
|
||||
t.Errorf("Test spec '%s' not found", test.specName)
|
||||
continue
|
||||
}
|
||||
|
||||
response := simulateTraceExecution(spec)
|
||||
fmt.Printf(" %s -> %s: %d events captured\n",
|
||||
test.specName, test.expected, response["event_count"])
|
||||
}
|
||||
fmt.Println()
|
||||
})
|
||||
|
||||
// Test 6: Test trace spec builder
|
||||
t.Run("TestTraceSpecBuilder", func(t *testing.T) {
|
||||
fmt.Printf("🏗️ Testing trace specification builder:\n")
|
||||
|
||||
// Build a custom trace spec
|
||||
spec := NewTraceSpecBuilder().
|
||||
Kprobe("__x64_sys_write").
|
||||
Format("write syscall: %d bytes", "arg3").
|
||||
Filter("arg1 == 1").
|
||||
Duration(3).
|
||||
Build()
|
||||
|
||||
fmt.Printf(" Built spec: Target=%s, Format=%s, Filter=%s\n",
|
||||
spec.Target, spec.Format, spec.Filter)
|
||||
|
||||
if spec.Target != "__x64_sys_write" {
|
||||
t.Error("Builder failed to set target correctly")
|
||||
}
|
||||
|
||||
if spec.ProbeType != "p" {
|
||||
t.Error("Builder failed to set probe type correctly")
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
})
|
||||
}
|
||||
|
||||
// simulateTraceExecution simulates what the agent would return for a trace execution
|
||||
// This documents the expected response format from the agent
|
||||
func simulateTraceExecution(spec TraceSpec) map[string]interface{} {
|
||||
// Simulate some trace events
|
||||
events := []TraceEvent{
|
||||
{
|
||||
Timestamp: time.Now().Unix(),
|
||||
PID: 1234,
|
||||
TID: 1234,
|
||||
ProcessName: "test_process",
|
||||
Function: spec.Target,
|
||||
Message: fmt.Sprintf(spec.Format, "test_file.txt"),
|
||||
RawArgs: map[string]string{
|
||||
"arg1": "5",
|
||||
"arg2": "test_file.txt",
|
||||
"arg3": "1024",
|
||||
},
|
||||
},
|
||||
{
|
||||
Timestamp: time.Now().Unix(),
|
||||
PID: 5678,
|
||||
TID: 5678,
|
||||
ProcessName: "another_process",
|
||||
Function: spec.Target,
|
||||
Message: fmt.Sprintf(spec.Format, "data.log"),
|
||||
RawArgs: map[string]string{
|
||||
"arg1": "3",
|
||||
"arg2": "data.log",
|
||||
"arg3": "512",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate trace statistics
|
||||
stats := TraceStats{
|
||||
TotalEvents: len(events),
|
||||
EventsByProcess: map[string]int{"test_process": 1, "another_process": 1},
|
||||
EventsByUID: map[int]int{1000: 2},
|
||||
EventsPerSecond: float64(len(events)) / float64(spec.Duration),
|
||||
TopProcesses: []ProcessStat{
|
||||
{ProcessName: "test_process", EventCount: 1, Percentage: 50.0},
|
||||
{ProcessName: "another_process", EventCount: 1, Percentage: 50.0},
|
||||
},
|
||||
}
|
||||
|
||||
// Return the expected agent response format
|
||||
return map[string]interface{}{
|
||||
"name": spec.Target,
|
||||
"type": "bcc_trace",
|
||||
"target": spec.Target,
|
||||
"duration": spec.Duration,
|
||||
"description": fmt.Sprintf("Traced %s for %d seconds", spec.Target, spec.Duration),
|
||||
"status": "completed",
|
||||
"success": true,
|
||||
"event_count": len(events),
|
||||
"events": events,
|
||||
"statistics": stats,
|
||||
"data_points": len(events),
|
||||
"probe_type": spec.ProbeType,
|
||||
"format": spec.Format,
|
||||
"filter": spec.Filter,
|
||||
}
|
||||
}
|
||||
|
||||
// TestTraceManagerCapabilities tests the trace manager capabilities
|
||||
func TestTraceManagerCapabilities(t *testing.T) {
|
||||
fmt.Println("=== BCC Trace Manager Capabilities Test ===")
|
||||
fmt.Println()
|
||||
|
||||
manager := NewBCCTraceManager()
|
||||
caps := manager.GetCapabilities()
|
||||
|
||||
fmt.Printf("🔧 Trace Manager Capabilities:\n")
|
||||
for capability, available := range caps {
|
||||
status := "❌ Not Available"
|
||||
if available {
|
||||
status = "✅ Available"
|
||||
}
|
||||
fmt.Printf(" %s: %s\n", capability, status)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
// Check essential capabilities
|
||||
if !caps["kernel_ebpf"] {
|
||||
fmt.Printf("⚠️ Warning: Kernel eBPF support not detected\n")
|
||||
}
|
||||
|
||||
if !caps["bpftrace"] {
|
||||
fmt.Printf("⚠️ Warning: bpftrace not available (install with: apt install bpftrace)\n")
|
||||
}
|
||||
|
||||
if !caps["root_access"] {
|
||||
fmt.Printf("⚠️ Warning: Root access required for eBPF tracing\n")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTraceSpecParsing benchmarks the trace specification parsing
|
||||
func BenchmarkTraceSpecParsing(b *testing.B) {
|
||||
parser := NewTraceSpecParser()
|
||||
testInput := "sys_open \"opening %s\", arg2@user"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := parser.ParseFromBCCStyle(testInput)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSyscallSuggestions tests the syscall suggestion functionality
|
||||
func TestSyscallSuggestions(t *testing.T) {
|
||||
fmt.Println("=== Syscall Suggestion Test ===")
|
||||
fmt.Println()
|
||||
|
||||
testCases := []struct {
|
||||
issue string
|
||||
expected int // minimum expected suggestions
|
||||
description string
|
||||
}{
|
||||
{
|
||||
issue: "file not found error",
|
||||
expected: 1,
|
||||
description: "File I/O issue should suggest file-related syscalls",
|
||||
},
|
||||
{
|
||||
issue: "network connection timeout",
|
||||
expected: 1,
|
||||
description: "Network issue should suggest network syscalls",
|
||||
},
|
||||
{
|
||||
issue: "process crashes randomly",
|
||||
expected: 1,
|
||||
description: "Process issue should suggest process-related syscalls",
|
||||
},
|
||||
{
|
||||
issue: "memory leak detected",
|
||||
expected: 1,
|
||||
description: "Memory issue should suggest memory syscalls",
|
||||
},
|
||||
{
|
||||
issue: "application is slow",
|
||||
expected: 1,
|
||||
description: "Performance issue should suggest monitoring syscalls",
|
||||
},
|
||||
}
|
||||
|
||||
fmt.Printf("💡 Testing syscall suggestions:\n")
|
||||
for _, tc := range testCases {
|
||||
suggestions := SuggestSyscallTargets(tc.issue)
|
||||
fmt.Printf(" Issue: '%s' -> %d suggestions: %v\n",
|
||||
tc.issue, len(suggestions), suggestions)
|
||||
|
||||
if len(suggestions) < tc.expected {
|
||||
t.Errorf("Expected at least %d suggestions for '%s', got %d",
|
||||
tc.expected, tc.issue, len(suggestions))
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// TestMain runs the tests and provides a summary
|
||||
func TestMain(m *testing.M) {
|
||||
fmt.Println("🚀 Starting BCC-Style eBPF Tracing Tests")
|
||||
fmt.Println("========================================")
|
||||
fmt.Println()
|
||||
|
||||
// Run capability check first
|
||||
manager := NewBCCTraceManager()
|
||||
caps := manager.GetCapabilities()
|
||||
|
||||
if !caps["kernel_ebpf"] {
|
||||
fmt.Println("⚠️ Kernel eBPF support not detected - some tests may be limited")
|
||||
}
|
||||
if !caps["bpftrace"] {
|
||||
fmt.Println("⚠️ bpftrace not available - install with: sudo apt install bpftrace")
|
||||
}
|
||||
if !caps["root_access"] {
|
||||
fmt.Println("⚠️ Root access required for actual eBPF tracing")
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
|
||||
// Run the tests
|
||||
code := m.Run()
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("========================================")
|
||||
if code == 0 {
|
||||
fmt.Println("✅ All BCC-Style eBPF Tracing Tests Passed!")
|
||||
} else {
|
||||
fmt.Println("❌ Some tests failed")
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// TestBCCTraceManagerRootTest tests the actual BCC trace manager with root privileges
|
||||
// This test requires root access and will only run meaningful tests when root
|
||||
func TestBCCTraceManagerRootTest(t *testing.T) {
|
||||
fmt.Println("=== BCC Trace Manager Root Test ===")
|
||||
|
||||
// Check if running as root
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("⚠️ Skipping root test - not running as root (use: sudo go test -run TestBCCTraceManagerRootTest)")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("✅ Running as root - can test actual eBPF functionality")
|
||||
|
||||
// Test 1: Create BCC trace manager and check capabilities
|
||||
manager := NewBCCTraceManager()
|
||||
caps := manager.GetCapabilities()
|
||||
|
||||
fmt.Printf("🔍 BCC Trace Manager Capabilities:\n")
|
||||
for cap, available := range caps {
|
||||
status := "❌"
|
||||
if available {
|
||||
status = "✅"
|
||||
}
|
||||
fmt.Printf(" %s %s: %v\n", status, cap, available)
|
||||
}
|
||||
|
||||
// Require essential capabilities
|
||||
if !caps["bpftrace"] {
|
||||
t.Fatal("❌ bpftrace not available - install bpftrace package")
|
||||
}
|
||||
|
||||
if !caps["root_access"] {
|
||||
t.Fatal("❌ Root access not detected")
|
||||
}
|
||||
|
||||
// Test 2: Create and execute a simple trace
|
||||
fmt.Println("\n🔬 Testing actual eBPF trace execution...")
|
||||
|
||||
spec := TraceSpec{
|
||||
ProbeType: "t", // tracepoint
|
||||
Target: "syscalls:sys_enter_openat",
|
||||
Format: "file access",
|
||||
Arguments: []string{}, // Remove invalid arg2@user for tracepoints
|
||||
Duration: 3, // 3 seconds
|
||||
}
|
||||
|
||||
fmt.Printf("📝 Starting trace: %s for %d seconds\n", spec.Target, spec.Duration)
|
||||
|
||||
traceID, err := manager.StartTrace(spec)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to start trace: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("🚀 Trace started with ID: %s\n", traceID)
|
||||
|
||||
// Generate some file access to capture
|
||||
go func() {
|
||||
time.Sleep(1 * time.Second)
|
||||
// Create some file operations to trace
|
||||
for i := 0; i < 3; i++ {
|
||||
testFile := fmt.Sprintf("/tmp/bcc_test_%d.txt", i)
|
||||
|
||||
// This will trigger sys_openat syscalls
|
||||
if file, err := os.Create(testFile); err == nil {
|
||||
file.WriteString("BCC trace test")
|
||||
file.Close()
|
||||
os.Remove(testFile)
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for trace to complete
|
||||
time.Sleep(time.Duration(spec.Duration+1) * time.Second)
|
||||
|
||||
// Get results
|
||||
result, err := manager.GetTraceResult(traceID)
|
||||
if err != nil {
|
||||
// Try to stop the trace if it's still running
|
||||
manager.StopTrace(traceID)
|
||||
t.Fatalf("❌ Failed to get trace results: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("\n📊 Trace Results Summary:\n")
|
||||
fmt.Printf(" • Trace ID: %s\n", result.TraceID)
|
||||
fmt.Printf(" • Target: %s\n", result.Spec.Target)
|
||||
fmt.Printf(" • Duration: %v\n", result.EndTime.Sub(result.StartTime))
|
||||
fmt.Printf(" • Events captured: %d\n", result.EventCount)
|
||||
fmt.Printf(" • Events per second: %.2f\n", result.Statistics.EventsPerSecond)
|
||||
fmt.Printf(" • Summary: %s\n", result.Summary)
|
||||
|
||||
if len(result.Events) > 0 {
|
||||
fmt.Printf("\n📝 Sample Events (first 3):\n")
|
||||
for i, event := range result.Events {
|
||||
if i >= 3 {
|
||||
break
|
||||
}
|
||||
fmt.Printf(" %d. PID:%d TID:%d Process:%s Message:%s\n",
|
||||
i+1, event.PID, event.TID, event.ProcessName, event.Message)
|
||||
}
|
||||
|
||||
if len(result.Events) > 3 {
|
||||
fmt.Printf(" ... and %d more events\n", len(result.Events)-3)
|
||||
}
|
||||
}
|
||||
|
||||
// Test 3: Validate the trace produced real data
|
||||
if result.EventCount == 0 {
|
||||
fmt.Println("⚠️ Warning: No events captured - this might be normal for a quiet system")
|
||||
} else {
|
||||
fmt.Printf("✅ Successfully captured %d real eBPF events!\n", result.EventCount)
|
||||
}
|
||||
|
||||
fmt.Println("\n🧪 Testing comprehensive system tracing (Network, Disk, CPU, Memory, Userspace)...")
|
||||
|
||||
testSpecs := []TraceSpec{
|
||||
// === SYSCALL TRACING ===
|
||||
{
|
||||
ProbeType: "p", // kprobe
|
||||
Target: "__x64_sys_write",
|
||||
Format: "write: fd=%d count=%d",
|
||||
Arguments: []string{"arg1", "arg3"},
|
||||
Duration: 2,
|
||||
},
|
||||
{
|
||||
ProbeType: "p", // kprobe
|
||||
Target: "__x64_sys_read",
|
||||
Format: "read: fd=%d count=%d",
|
||||
Arguments: []string{"arg1", "arg3"},
|
||||
Duration: 2,
|
||||
},
|
||||
{
|
||||
ProbeType: "p", // kprobe
|
||||
Target: "__x64_sys_connect",
|
||||
Format: "network connect: fd=%d",
|
||||
Arguments: []string{"arg1"},
|
||||
Duration: 2,
|
||||
},
|
||||
{
|
||||
ProbeType: "p", // kprobe
|
||||
Target: "__x64_sys_accept",
|
||||
Format: "network accept: fd=%d",
|
||||
Arguments: []string{"arg1"},
|
||||
Duration: 2,
|
||||
},
|
||||
// === BLOCK I/O TRACING ===
|
||||
{
|
||||
ProbeType: "t", // tracepoint
|
||||
Target: "block:block_io_start",
|
||||
Format: "block I/O start",
|
||||
Arguments: []string{},
|
||||
Duration: 2,
|
||||
},
|
||||
{
|
||||
ProbeType: "t", // tracepoint
|
||||
Target: "block:block_io_done",
|
||||
Format: "block I/O complete",
|
||||
Arguments: []string{},
|
||||
Duration: 2,
|
||||
},
|
||||
// === CPU SCHEDULER TRACING ===
|
||||
{
|
||||
ProbeType: "t", // tracepoint
|
||||
Target: "sched:sched_migrate_task",
|
||||
Format: "task migration",
|
||||
Arguments: []string{},
|
||||
Duration: 2,
|
||||
},
|
||||
{
|
||||
ProbeType: "t", // tracepoint
|
||||
Target: "sched:sched_pi_setprio",
|
||||
Format: "priority change",
|
||||
Arguments: []string{},
|
||||
Duration: 2,
|
||||
},
|
||||
// === MEMORY MANAGEMENT ===
|
||||
{
|
||||
ProbeType: "t", // tracepoint
|
||||
Target: "syscalls:sys_enter_brk",
|
||||
Format: "memory allocation: brk",
|
||||
Arguments: []string{},
|
||||
Duration: 2,
|
||||
},
|
||||
// === KERNEL MEMORY TRACING ===
|
||||
{
|
||||
ProbeType: "t", // tracepoint
|
||||
Target: "kmem:kfree",
|
||||
Format: "kernel memory free",
|
||||
Arguments: []string{},
|
||||
Duration: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for i, testSpec := range testSpecs {
|
||||
category := "unknown"
|
||||
if strings.Contains(testSpec.Target, "sys_write") || strings.Contains(testSpec.Target, "sys_read") {
|
||||
category = "filesystem"
|
||||
} else if strings.Contains(testSpec.Target, "sys_connect") || strings.Contains(testSpec.Target, "sys_accept") {
|
||||
category = "network"
|
||||
} else if strings.Contains(testSpec.Target, "block:") {
|
||||
category = "disk I/O"
|
||||
} else if strings.Contains(testSpec.Target, "sched:") {
|
||||
category = "CPU/scheduler"
|
||||
} else if strings.Contains(testSpec.Target, "sys_brk") || strings.Contains(testSpec.Target, "kmem:") {
|
||||
category = "memory"
|
||||
}
|
||||
|
||||
fmt.Printf("\n 🔍 Test %d: [%s] Tracing %s for %d seconds\n", i+1, category, testSpec.Target, testSpec.Duration)
|
||||
|
||||
testTraceID, err := manager.StartTrace(testSpec)
|
||||
if err != nil {
|
||||
fmt.Printf(" ❌ Failed to start: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generate activity specific to this trace type
|
||||
go func(target, probeType string) {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
switch {
|
||||
case strings.Contains(target, "sys_write") || strings.Contains(target, "sys_read"):
|
||||
// Generate file I/O
|
||||
for j := 0; j < 3; j++ {
|
||||
testFile := fmt.Sprintf("/tmp/io_test_%d.txt", j)
|
||||
if file, err := os.Create(testFile); err == nil {
|
||||
file.WriteString("BCC tracing test data for I/O operations")
|
||||
file.Sync()
|
||||
file.Close()
|
||||
|
||||
// Read the file back
|
||||
if readFile, err := os.Open(testFile); err == nil {
|
||||
buffer := make([]byte, 1024)
|
||||
readFile.Read(buffer)
|
||||
readFile.Close()
|
||||
}
|
||||
os.Remove(testFile)
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
case strings.Contains(target, "block:"):
|
||||
// Generate disk I/O to trigger block layer events
|
||||
for j := 0; j < 3; j++ {
|
||||
testFile := fmt.Sprintf("/tmp/block_test_%d.txt", j)
|
||||
if file, err := os.Create(testFile); err == nil {
|
||||
// Write substantial data to trigger block I/O
|
||||
data := make([]byte, 1024*4) // 4KB
|
||||
for k := range data {
|
||||
data[k] = byte(k % 256)
|
||||
}
|
||||
file.Write(data)
|
||||
file.Sync() // Force write to disk
|
||||
file.Close()
|
||||
}
|
||||
os.Remove(testFile)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
}
|
||||
case strings.Contains(target, "sched:"):
|
||||
// Generate CPU activity to trigger scheduler events
|
||||
go func() {
|
||||
for j := 0; j < 100; j++ {
|
||||
// Create short-lived goroutines to trigger scheduler activity
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 1)
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
}
|
||||
}()
|
||||
case strings.Contains(target, "sys_brk") || strings.Contains(target, "kmem:"):
|
||||
// Generate memory allocation activity
|
||||
for j := 0; j < 5; j++ {
|
||||
// Allocate and free memory to trigger memory management
|
||||
data := make([]byte, 1024*1024) // 1MB
|
||||
for k := range data {
|
||||
data[k] = byte(k % 256)
|
||||
}
|
||||
data = nil // Allow GC
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
case strings.Contains(target, "sys_connect") || strings.Contains(target, "sys_accept"):
|
||||
// Network operations (these may not generate events in test environment)
|
||||
fmt.Printf(" Note: Network syscalls may not trigger events without actual network activity\n")
|
||||
default:
|
||||
// Generic activity
|
||||
for j := 0; j < 3; j++ {
|
||||
testFile := fmt.Sprintf("/tmp/generic_test_%d.txt", j)
|
||||
if file, err := os.Create(testFile); err == nil {
|
||||
file.WriteString("Generic test activity")
|
||||
file.Close()
|
||||
}
|
||||
os.Remove(testFile)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}(testSpec.Target, testSpec.ProbeType)
|
||||
|
||||
// Wait for trace completion
|
||||
time.Sleep(time.Duration(testSpec.Duration+1) * time.Second)
|
||||
|
||||
testResult, err := manager.GetTraceResult(testTraceID)
|
||||
if err != nil {
|
||||
manager.StopTrace(testTraceID)
|
||||
fmt.Printf(" ⚠️ Result error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf(" 📊 Results for %s:\n", testSpec.Target)
|
||||
fmt.Printf(" • Total events: %d\n", testResult.EventCount)
|
||||
fmt.Printf(" • Events/sec: %.2f\n", testResult.Statistics.EventsPerSecond)
|
||||
fmt.Printf(" • Duration: %v\n", testResult.EndTime.Sub(testResult.StartTime))
|
||||
|
||||
// Show process breakdown
|
||||
if len(testResult.Statistics.TopProcesses) > 0 {
|
||||
fmt.Printf(" • Top processes:\n")
|
||||
for j, proc := range testResult.Statistics.TopProcesses {
|
||||
if j >= 3 { // Show top 3
|
||||
break
|
||||
}
|
||||
fmt.Printf(" - %s: %d events (%.1f%%)\n",
|
||||
proc.ProcessName, proc.EventCount, proc.Percentage)
|
||||
}
|
||||
}
|
||||
|
||||
// Show sample events with PIDs, counts, etc.
|
||||
if len(testResult.Events) > 0 {
|
||||
fmt.Printf(" • Sample events:\n")
|
||||
for j, event := range testResult.Events {
|
||||
if j >= 5 { // Show first 5 events
|
||||
break
|
||||
}
|
||||
fmt.Printf(" [%d] PID:%d TID:%d Process:%s Message:%s\n",
|
||||
j+1, event.PID, event.TID, event.ProcessName, event.Message)
|
||||
}
|
||||
if len(testResult.Events) > 5 {
|
||||
fmt.Printf(" ... and %d more events\n", len(testResult.Events)-5)
|
||||
}
|
||||
}
|
||||
|
||||
if testResult.EventCount > 0 {
|
||||
fmt.Printf(" ✅ Success: Captured %d real syscall events!\n", testResult.EventCount)
|
||||
} else {
|
||||
fmt.Printf(" ⚠️ No events captured (may be normal for this syscall)\n")
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\n🎉 BCC Trace Manager Root Test Complete!")
|
||||
fmt.Println("✅ Real eBPF tracing is working and ready for production use!")
|
||||
}
|
||||
|
||||
// TestAgentEBPFIntegration tests the agent's integration with BCC-style eBPF tracing
|
||||
// This demonstrates the complete flow from agent to eBPF results
|
||||
func TestAgentEBPFIntegration(t *testing.T) {
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("⚠️ Skipping agent integration test - requires root access")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("\n=== Agent eBPF Integration Test ===")
|
||||
fmt.Println("This test demonstrates the complete agent flow with BCC-style tracing")
|
||||
|
||||
// Create eBPF manager directly for testing
|
||||
manager := NewBCCTraceManager()
|
||||
|
||||
// Test multiple syscalls that would be sent by remote API
|
||||
testEBPFRequests := []struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Duration int `json:"duration"`
|
||||
Description string `json:"description"`
|
||||
Filters map[string]string `json:"filters"`
|
||||
}{
|
||||
{
|
||||
Name: "file_operations",
|
||||
Type: "syscall",
|
||||
Target: "sys_openat", // Will be converted to __x64_sys_openat
|
||||
Duration: 3,
|
||||
Description: "trace file open operations",
|
||||
Filters: map[string]string{},
|
||||
},
|
||||
{
|
||||
Name: "network_operations",
|
||||
Type: "syscall",
|
||||
Target: "__x64_sys_connect",
|
||||
Duration: 2,
|
||||
Description: "trace network connections",
|
||||
Filters: map[string]string{},
|
||||
},
|
||||
{
|
||||
Name: "io_operations",
|
||||
Type: "syscall",
|
||||
Target: "sys_write",
|
||||
Duration: 2,
|
||||
Description: "trace write operations",
|
||||
Filters: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
fmt.Printf("🚀 Testing eBPF manager with %d eBPF programs...\n\n", len(testEBPFRequests))
|
||||
|
||||
// Convert to trace specs and execute using manager directly
|
||||
var traceSpecs []TraceSpec
|
||||
for _, req := range testEBPFRequests {
|
||||
spec := TraceSpec{
|
||||
ProbeType: "p", // kprobe
|
||||
Target: "__x64_" + req.Target,
|
||||
Format: req.Description,
|
||||
Duration: req.Duration,
|
||||
}
|
||||
traceSpecs = append(traceSpecs, spec)
|
||||
}
|
||||
|
||||
// Execute traces sequentially for testing
|
||||
var results []map[string]interface{}
|
||||
for i, spec := range traceSpecs {
|
||||
fmt.Printf("Starting trace %d: %s\n", i+1, spec.Target)
|
||||
|
||||
traceID, err := manager.StartTrace(spec)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to start trace: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Wait for trace duration
|
||||
time.Sleep(time.Duration(spec.Duration) * time.Second)
|
||||
|
||||
traceResult, err := manager.GetTraceResult(traceID)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to get results: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"name": testEBPFRequests[i].Name,
|
||||
"target": spec.Target,
|
||||
"success": true,
|
||||
"event_count": traceResult.EventCount,
|
||||
"summary": traceResult.Summary,
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
fmt.Printf("📊 Agent eBPF Execution Results:\n")
|
||||
fmt.Printf("=" + strings.Repeat("=", 50) + "\n\n")
|
||||
|
||||
for i, result := range results {
|
||||
fmt.Printf("🔍 Program %d: %s\n", i+1, result["name"])
|
||||
fmt.Printf(" Target: %s\n", result["target"])
|
||||
fmt.Printf(" Type: %s\n", result["type"])
|
||||
fmt.Printf(" Status: %s\n", result["status"])
|
||||
fmt.Printf(" Success: %v\n", result["success"])
|
||||
|
||||
if result["success"].(bool) {
|
||||
if eventCount, ok := result["event_count"].(int); ok {
|
||||
fmt.Printf(" Events captured: %d\n", eventCount)
|
||||
}
|
||||
if dataPoints, ok := result["data_points"].(int); ok {
|
||||
fmt.Printf(" Data points: %d\n", dataPoints)
|
||||
}
|
||||
if summary, ok := result["summary"].(string); ok {
|
||||
fmt.Printf(" Summary: %s\n", summary)
|
||||
}
|
||||
|
||||
// Show events if available
|
||||
if events, ok := result["events"].([]TraceEvent); ok && len(events) > 0 {
|
||||
fmt.Printf(" Sample events:\n")
|
||||
for j, event := range events {
|
||||
if j >= 3 { // Show first 3
|
||||
break
|
||||
}
|
||||
fmt.Printf(" [%d] PID:%d Process:%s Message:%s\n",
|
||||
j+1, event.PID, event.ProcessName, event.Message)
|
||||
}
|
||||
if len(events) > 3 {
|
||||
fmt.Printf(" ... and %d more events\n", len(events)-3)
|
||||
}
|
||||
}
|
||||
|
||||
// Show statistics if available
|
||||
if stats, ok := result["statistics"].(TraceStats); ok {
|
||||
fmt.Printf(" Statistics:\n")
|
||||
fmt.Printf(" - Events/sec: %.2f\n", stats.EventsPerSecond)
|
||||
fmt.Printf(" - Total processes: %d\n", len(stats.EventsByProcess))
|
||||
if len(stats.TopProcesses) > 0 {
|
||||
fmt.Printf(" - Top process: %s (%d events)\n",
|
||||
stats.TopProcesses[0].ProcessName, stats.TopProcesses[0].EventCount)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if errMsg, ok := result["error"].(string); ok {
|
||||
fmt.Printf(" Error: %s\n", errMsg)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// Validate expected agent response format
|
||||
t.Run("ValidateAgentResponseFormat", func(t *testing.T) {
|
||||
for i, result := range results {
|
||||
// Check required fields
|
||||
requiredFields := []string{"name", "type", "target", "duration", "description", "status", "success"}
|
||||
for _, field := range requiredFields {
|
||||
if _, exists := result[field]; !exists {
|
||||
t.Errorf("Result %d missing required field: %s", i, field)
|
||||
}
|
||||
}
|
||||
|
||||
// If successful, check for data fields
|
||||
if success, ok := result["success"].(bool); ok && success {
|
||||
// Should have either event_count or data_points
|
||||
hasEventCount := false
|
||||
hasDataPoints := false
|
||||
|
||||
if _, ok := result["event_count"]; ok {
|
||||
hasEventCount = true
|
||||
}
|
||||
if _, ok := result["data_points"]; ok {
|
||||
hasDataPoints = true
|
||||
}
|
||||
|
||||
if !hasEventCount && !hasDataPoints {
|
||||
t.Errorf("Successful result %d should have event_count or data_points", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
fmt.Println("✅ Agent eBPF Integration Test Complete!")
|
||||
fmt.Println("📈 The agent correctly processes eBPF requests and returns detailed syscall data!")
|
||||
}
|
||||
110
internal/executor/executor.go
Normal file
110
internal/executor/executor.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nannyagentv2/internal/types"
|
||||
)
|
||||
|
||||
// CommandExecutor runs diagnostic shell commands after validating them,
// bounding each run by a fixed timeout.
type CommandExecutor struct {
	// timeout is the maximum time a single command is allowed to run.
	timeout time.Duration
}
|
||||
|
||||
// NewCommandExecutor creates a new command executor with specified timeout
|
||||
func NewCommandExecutor(timeout time.Duration) *CommandExecutor {
|
||||
return &CommandExecutor{
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute executes a command safely with timeout and validation
|
||||
func (ce *CommandExecutor) Execute(cmd types.Command) types.CommandResult {
|
||||
result := types.CommandResult{
|
||||
ID: cmd.ID,
|
||||
Command: cmd.Command,
|
||||
}
|
||||
|
||||
// Validate command safety
|
||||
if err := ce.validateCommand(cmd.Command); err != nil {
|
||||
result.Error = fmt.Sprintf("unsafe command: %s", err.Error())
|
||||
result.ExitCode = 1
|
||||
return result
|
||||
}
|
||||
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ce.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Execute command using shell for proper handling of pipes, redirects, etc.
|
||||
execCmd := exec.CommandContext(ctx, "/bin/bash", "-c", cmd.Command)
|
||||
|
||||
output, err := execCmd.CombinedOutput()
|
||||
result.Output = string(output)
|
||||
|
||||
if err != nil {
|
||||
result.Error = err.Error()
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitError.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = 1
|
||||
}
|
||||
} else {
|
||||
result.ExitCode = 0
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// validateCommand checks if a command is safe to execute
|
||||
func (ce *CommandExecutor) validateCommand(command string) error {
|
||||
// Convert to lowercase for case-insensitive checking
|
||||
cmd := strings.ToLower(strings.TrimSpace(command))
|
||||
|
||||
// List of dangerous commands/patterns
|
||||
dangerousPatterns := []string{
|
||||
"rm ", "rm\t", "rm\n",
|
||||
"mv ", "mv\t", "mv\n",
|
||||
"dd ", "dd\t", "dd\n",
|
||||
"mkfs", "fdisk", "parted",
|
||||
"shutdown", "reboot", "halt", "poweroff",
|
||||
"passwd", "userdel", "usermod",
|
||||
"chmod", "chown", "chgrp",
|
||||
"systemctl stop", "systemctl disable", "systemctl mask",
|
||||
"service stop", "service disable",
|
||||
"kill ", "killall", "pkill",
|
||||
"crontab -r", "crontab -e",
|
||||
"iptables -F", "iptables -D", "iptables -I",
|
||||
"umount ", "unmount ", // Allow mount but not umount
|
||||
"wget ", "curl ", // Prevent network operations
|
||||
"| dd", "| rm", "| mv", // Prevent piping to dangerous commands
|
||||
}
|
||||
|
||||
// Check for dangerous patterns
|
||||
for _, pattern := range dangerousPatterns {
|
||||
if strings.Contains(cmd, pattern) {
|
||||
return fmt.Errorf("command contains dangerous pattern: %s", pattern)
|
||||
}
|
||||
}
|
||||
|
||||
// Additional checks for commands that start with dangerous operations
|
||||
if strings.HasPrefix(cmd, "rm ") || strings.HasPrefix(cmd, "rm\t") {
|
||||
return fmt.Errorf("rm command not allowed")
|
||||
}
|
||||
|
||||
// Check for sudo usage (we want to avoid automated sudo commands)
|
||||
if strings.HasPrefix(cmd, "sudo ") {
|
||||
return fmt.Errorf("sudo commands not allowed for automated execution")
|
||||
}
|
||||
|
||||
// Check for dangerous redirections (but allow safe ones like 2>/dev/null)
|
||||
if strings.Contains(cmd, ">") && !strings.Contains(cmd, "2>/dev/null") && !strings.Contains(cmd, ">/dev/null") {
|
||||
return fmt.Errorf("file redirection not allowed except to /dev/null")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
529
internal/server/investigation_server.go
Normal file
529
internal/server/investigation_server.go
Normal file
@@ -0,0 +1,529 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nannyagentv2/internal/auth"
|
||||
"nannyagentv2/internal/logging"
|
||||
"nannyagentv2/internal/metrics"
|
||||
"nannyagentv2/internal/types"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// InvestigationRequest is the JSON payload Supabase sends to ask the agent
// to start an investigation.
type InvestigationRequest struct {
	// InvestigationID identifies the investigation.
	InvestigationID string `json:"investigation_id"`
	// ApplicationGroup is carried under the "application_group" key.
	ApplicationGroup string `json:"application_group"`
	// Issue is the problem text to investigate.
	Issue string `json:"issue"`
	// Context holds free-form string key/value pairs accompanying the issue.
	Context map[string]string `json:"context"`
	// Priority is carried under the "priority" key.
	Priority string `json:"priority"`
	// InitiatedBy is carried under the "initiated_by" key.
	InitiatedBy string `json:"initiated_by"`
}
|
||||
|
||||
// InvestigationResponse represents the agent's response to an investigation
|
||||
type InvestigationResponse struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
InvestigationID string `json:"investigation_id"`
|
||||
Status string `json:"status"`
|
||||
Commands []types.CommandResult `json:"commands,omitempty"`
|
||||
AIResponse string `json:"ai_response,omitempty"`
|
||||
EpisodeID string `json:"episode_id,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// InvestigationServer handles reverse investigation requests from Supabase
|
||||
type InvestigationServer struct {
|
||||
agent types.DiagnosticAgent // Original agent for direct user interactions
|
||||
applicationAgent types.DiagnosticAgent // Separate agent for application-initiated investigations
|
||||
port string
|
||||
agentID string
|
||||
metricsCollector *metrics.Collector
|
||||
authManager *auth.AuthManager
|
||||
startTime time.Time
|
||||
supabaseURL string
|
||||
}
|
||||
|
||||
// NewInvestigationServer creates a new investigation server
|
||||
func NewInvestigationServer(agent types.DiagnosticAgent, authManager *auth.AuthManager) *InvestigationServer {
|
||||
port := os.Getenv("AGENT_PORT")
|
||||
if port == "" {
|
||||
port = "1234"
|
||||
}
|
||||
|
||||
// Get agent ID from authentication system
|
||||
var agentID string
|
||||
if authManager != nil {
|
||||
if id, err := authManager.GetCurrentAgentID(); err == nil {
|
||||
agentID = id
|
||||
|
||||
} else {
|
||||
logging.Error("Failed to get agent ID from auth manager: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to environment variable or generate one if auth fails
|
||||
if agentID == "" {
|
||||
agentID = os.Getenv("AGENT_ID")
|
||||
if agentID == "" {
|
||||
agentID = fmt.Sprintf("agent-%d", time.Now().Unix())
|
||||
}
|
||||
}
|
||||
|
||||
// Create metrics collector
|
||||
metricsCollector := metrics.NewCollector("v2.0.0")
|
||||
|
||||
// TODO: Fix application agent creation - use main agent for now
|
||||
// Create a separate agent for application-initiated investigations
|
||||
// applicationAgent := NewLinuxDiagnosticAgent()
|
||||
// Override the model to use the application-specific function
|
||||
// applicationAgent.model = "tensorzero::function_name::diagnose_and_heal_application"
|
||||
|
||||
return &InvestigationServer{
|
||||
agent: agent,
|
||||
applicationAgent: agent, // Use same agent for now
|
||||
port: port,
|
||||
agentID: agentID,
|
||||
metricsCollector: metricsCollector,
|
||||
authManager: authManager,
|
||||
startTime: time.Now(),
|
||||
supabaseURL: os.Getenv("SUPABASE_PROJECT_URL"),
|
||||
}
|
||||
}
|
||||
|
||||
// DiagnoseIssueForApplication handles diagnostic requests initiated from application/portal
|
||||
func (s *InvestigationServer) DiagnoseIssueForApplication(issue, episodeID string) error {
|
||||
// Set the episode ID on the application agent for continuity
|
||||
// TODO: Fix episode ID handling with interface
|
||||
// s.applicationAgent.episodeID = episodeID
|
||||
return s.applicationAgent.DiagnoseIssue(issue)
|
||||
}
|
||||
|
||||
// Start starts the HTTP server and realtime polling for investigation requests
|
||||
func (s *InvestigationServer) Start() error {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Health check endpoint
|
||||
mux.HandleFunc("/health", s.handleHealth)
|
||||
|
||||
// Investigation endpoint
|
||||
mux.HandleFunc("/investigate", s.handleInvestigation)
|
||||
|
||||
// Agent status endpoint
|
||||
mux.HandleFunc("/status", s.handleStatus)
|
||||
|
||||
// Start realtime polling for backend-initiated investigations
|
||||
if s.supabaseURL != "" && s.authManager != nil {
|
||||
go s.startRealtimePolling()
|
||||
logging.Info("Realtime investigation polling enabled")
|
||||
} else {
|
||||
logging.Warning("Realtime investigation polling disabled (missing Supabase config or auth)")
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Addr: ":" + s.port,
|
||||
Handler: mux,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
logging.Info("Investigation server started on port %s (Agent ID: %s)", s.port, s.agentID)
|
||||
return server.ListenAndServe()
|
||||
}
|
||||
|
||||
// handleHealth responds to health check requests
|
||||
func (s *InvestigationServer) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"status": "healthy",
|
||||
"agent_id": s.agentID,
|
||||
"timestamp": time.Now(),
|
||||
"version": "v2.0.0",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// handleStatus responds with agent status and capabilities
|
||||
func (s *InvestigationServer) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Collect current system metrics
|
||||
systemMetrics, err := s.metricsCollector.GatherSystemMetrics()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to collect metrics: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to metrics request format for consistent data structure
|
||||
metricsReq := s.metricsCollector.CreateMetricsRequest(s.agentID, systemMetrics)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"agent_id": s.agentID,
|
||||
"status": "ready",
|
||||
"capabilities": []string{"system_diagnostics", "ebpf_monitoring", "command_execution", "ai_analysis"},
|
||||
"system_info": map[string]interface{}{
|
||||
"os": fmt.Sprintf("%s %s", metricsReq.OSInfo["platform"], metricsReq.OSInfo["platform_version"]),
|
||||
"kernel": metricsReq.KernelVersion,
|
||||
"architecture": metricsReq.OSInfo["kernel_arch"],
|
||||
"cpu_cores": metricsReq.OSInfo["cpu_cores"],
|
||||
"memory": metricsReq.MemoryUsage,
|
||||
"private_ips": metricsReq.IPAddress,
|
||||
"load_average": fmt.Sprintf("%.2f, %.2f, %.2f",
|
||||
metricsReq.LoadAverages["load1"],
|
||||
metricsReq.LoadAverages["load5"],
|
||||
metricsReq.LoadAverages["load15"]),
|
||||
"disk_usage": fmt.Sprintf("Root: %.0fG/%.0fG (%.0f%% used)",
|
||||
float64(metricsReq.FilesystemInfo[0].Used)/1024/1024/1024,
|
||||
float64(metricsReq.FilesystemInfo[0].Total)/1024/1024/1024,
|
||||
metricsReq.DiskUsage),
|
||||
},
|
||||
"uptime": time.Since(s.startTime),
|
||||
"last_contact": time.Now(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// sendCommandResultsToTensorZero sends command results back to TensorZero and continues conversation
|
||||
func (s *InvestigationServer) sendCommandResultsToTensorZero(diagnosticResp types.DiagnosticResponse, commandResults []types.CommandResult) (interface{}, error) {
|
||||
// Build conversation history like in agent.go
|
||||
messages := []openai.ChatCompletionMessage{
|
||||
// Add the original diagnostic response as assistant message
|
||||
{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: fmt.Sprintf(`{"response_type":"diagnostic","reasoning":"%s","commands":%s}`,
|
||||
diagnosticResp.Reasoning,
|
||||
mustMarshalJSON(diagnosticResp.Commands)),
|
||||
},
|
||||
}
|
||||
|
||||
// Add command results as user message (same as agent.go does)
|
||||
resultsJSON, err := json.MarshalIndent(commandResults, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal command results: %w", err)
|
||||
}
|
||||
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: string(resultsJSON),
|
||||
})
|
||||
|
||||
// Send to TensorZero via application agent's sendRequest method
|
||||
logging.Debug("Sending command results to TensorZero for analysis")
|
||||
response, err := s.applicationAgent.SendRequest(messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request to TensorZero: %w", err)
|
||||
}
|
||||
|
||||
if len(response.Choices) == 0 {
|
||||
return nil, fmt.Errorf("no choices in TensorZero response")
|
||||
}
|
||||
|
||||
content := response.Choices[0].Message.Content
|
||||
logging.Debug("TensorZero continued analysis: %s", content)
|
||||
|
||||
// Try to parse the response to determine if it's diagnostic or resolution
|
||||
var diagnosticNextResp types.DiagnosticResponse
|
||||
var resolutionResp types.ResolutionResponse
|
||||
|
||||
// Check if it's another diagnostic response
|
||||
if err := json.Unmarshal([]byte(content), &diagnosticNextResp); err == nil && diagnosticNextResp.ResponseType == "diagnostic" {
|
||||
logging.Debug("TensorZero requests %d more commands", len(diagnosticNextResp.Commands))
|
||||
return map[string]interface{}{
|
||||
"type": "diagnostic",
|
||||
"response": diagnosticNextResp,
|
||||
"raw": content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Check if it's a resolution response
|
||||
if err := json.Unmarshal([]byte(content), &resolutionResp); err == nil && resolutionResp.ResponseType == "resolution" {
|
||||
|
||||
return map[string]interface{}{
|
||||
"type": "resolution",
|
||||
"response": resolutionResp,
|
||||
"raw": content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Return raw response if we can't parse it
|
||||
return map[string]interface{}{
|
||||
"type": "unknown",
|
||||
"raw": content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// mustMarshalJSON returns v encoded as a JSON string. Encoding errors are
// deliberately discarded; in that case the result is the empty string.
func mustMarshalJSON(v interface{}) string {
	if encoded, err := json.Marshal(v); err == nil {
		return string(encoded)
	}
	return ""
}
|
||||
|
||||
// processInvestigation handles the actual investigation using TensorZero
|
||||
// This endpoint receives either:
|
||||
// 1. DiagnosticResponse - Commands and eBPF programs to execute
|
||||
// 2. ResolutionResponse - Final resolution (no execution needed)
|
||||
func (s *InvestigationServer) handleInvestigation(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed - only POST accepted", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the request body to determine what type of response this is
|
||||
var requestBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check the response_type field to determine how to handle this
|
||||
responseType, ok := requestBody["response_type"].(string)
|
||||
if !ok {
|
||||
http.Error(w, "Missing or invalid response_type field", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
logging.Debug("Received investigation payload with response_type: %s", responseType)
|
||||
|
||||
switch responseType {
|
||||
case "diagnostic":
|
||||
// This is a DiagnosticResponse with commands to execute
|
||||
response := s.handleDiagnosticExecution(requestBody)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
|
||||
case "resolution":
|
||||
// This is a ResolutionResponse - final result, just acknowledge
|
||||
fmt.Printf("📋 Received final resolution from backend\n")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Resolution received and acknowledged",
|
||||
"agent_id": s.agentID,
|
||||
})
|
||||
|
||||
default:
|
||||
http.Error(w, fmt.Sprintf("Unknown response_type: %s", responseType), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleDiagnosticExecution executes commands from a DiagnosticResponse
|
||||
func (s *InvestigationServer) handleDiagnosticExecution(requestBody map[string]interface{}) map[string]interface{} {
|
||||
// Parse as DiagnosticResponse
|
||||
var diagnosticResp types.DiagnosticResponse
|
||||
|
||||
// Convert the map back to JSON and then parse it properly
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("Failed to re-marshal request: %v", err),
|
||||
"agent_id": s.agentID,
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &diagnosticResp); err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("Failed to parse DiagnosticResponse: %v", err),
|
||||
"agent_id": s.agentID,
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("📋 Executing %d commands from backend\n", len(diagnosticResp.Commands))
|
||||
|
||||
// Execute all commands
|
||||
commandResults := make([]types.CommandResult, 0, len(diagnosticResp.Commands))
|
||||
|
||||
for _, cmd := range diagnosticResp.Commands {
|
||||
fmt.Printf("⚙️ Executing command '%s': %s\n", cmd.ID, cmd.Command)
|
||||
|
||||
// Use the agent's executor to run the command
|
||||
result := s.agent.ExecuteCommand(cmd)
|
||||
commandResults = append(commandResults, result)
|
||||
|
||||
if result.Error != "" {
|
||||
fmt.Printf("⚠️ Command '%s' had error: %s\n", cmd.ID, result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Send command results back to TensorZero for continued analysis
|
||||
fmt.Printf("🔄 Sending %d command results back to TensorZero for continued analysis\n", len(commandResults))
|
||||
|
||||
nextResponse, err := s.sendCommandResultsToTensorZero(diagnosticResp, commandResults)
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("Failed to continue TensorZero conversation: %v", err),
|
||||
"agent_id": s.agentID,
|
||||
"command_results": commandResults, // Still return the results
|
||||
}
|
||||
}
|
||||
|
||||
// Return both the command results and the next response from TensorZero
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"agent_id": s.agentID,
|
||||
"command_results": commandResults,
|
||||
"commands_executed": len(commandResults),
|
||||
"next_response": nextResponse,
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// PendingInvestigation mirrors one row of the pending_investigations table
// as returned by the Supabase REST API.
type PendingInvestigation struct {
	// ID is the row identifier used when updating the row's status.
	ID string `json:"id"`
	// InvestigationID identifies the investigation this row belongs to.
	InvestigationID string `json:"investigation_id"`
	// AgentID is carried under the "agent_id" key.
	AgentID string `json:"agent_id"`
	// DiagnosticPayload is the generic diagnostic payload passed to command execution.
	DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
	// EpisodeID is optional and nil when absent.
	EpisodeID *string `json:"episode_id"`
	// Status is carried under the "status" key.
	Status string `json:"status"`
	// CreatedAt is carried under the "created_at" key.
	CreatedAt time.Time `json:"created_at"`
}
|
||||
|
||||
// startRealtimePolling begins polling for pending investigations
|
||||
func (s *InvestigationServer) startRealtimePolling() {
|
||||
fmt.Printf("🔄 Starting realtime investigation polling for agent %s\n", s.agentID)
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second) // Poll every 5 seconds
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
s.checkForPendingInvestigations()
|
||||
}
|
||||
}
|
||||
|
||||
// checkForPendingInvestigations checks for new pending investigations
|
||||
func (s *InvestigationServer) checkForPendingInvestigations() {
|
||||
url := fmt.Sprintf("%s/rest/v1/pending_investigations?agent_id=eq.%s&status=eq.pending&order=created_at.desc",
|
||||
s.supabaseURL, s.agentID)
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return // Silent fail for polling
|
||||
}
|
||||
|
||||
// Get token from auth manager
|
||||
authToken, err := s.authManager.LoadToken()
|
||||
if err != nil {
|
||||
return // Silent fail for polling
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken.AccessToken))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return // Silent fail for polling
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return // Silent fail for polling
|
||||
}
|
||||
|
||||
var investigations []PendingInvestigation
|
||||
err = json.NewDecoder(resp.Body).Decode(&investigations)
|
||||
if err != nil {
|
||||
return // Silent fail for polling
|
||||
}
|
||||
|
||||
for _, investigation := range investigations {
|
||||
fmt.Printf("🔍 Found pending investigation: %s\n", investigation.ID)
|
||||
go s.handlePendingInvestigation(investigation)
|
||||
}
|
||||
}
|
||||
|
||||
// handlePendingInvestigation processes a single pending investigation
|
||||
func (s *InvestigationServer) handlePendingInvestigation(investigation PendingInvestigation) {
|
||||
fmt.Printf("🚀 Processing realtime investigation %s\n", investigation.InvestigationID)
|
||||
|
||||
// Mark as executing
|
||||
err := s.updateInvestigationStatus(investigation.ID, "executing", nil, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ Failed to mark investigation as executing: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute diagnostic commands using existing handleDiagnosticExecution method
|
||||
results := s.handleDiagnosticExecution(investigation.DiagnosticPayload)
|
||||
|
||||
// Mark as completed with results
|
||||
err = s.updateInvestigationStatus(investigation.ID, "completed", results, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ Failed to mark investigation as completed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// updateInvestigationStatus updates the status of a pending investigation
|
||||
func (s *InvestigationServer) updateInvestigationStatus(id, status string, results map[string]interface{}, errorMsg *string) error {
|
||||
updateData := map[string]interface{}{
|
||||
"status": status,
|
||||
}
|
||||
|
||||
if status == "executing" {
|
||||
updateData["started_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
} else if status == "completed" {
|
||||
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
if results != nil {
|
||||
updateData["command_results"] = results
|
||||
}
|
||||
} else if status == "failed" && errorMsg != nil {
|
||||
updateData["error_message"] = *errorMsg
|
||||
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(updateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal update data: %v", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/rest/v1/pending_investigations?id=eq.%s", s.supabaseURL, id)
|
||||
req, err := http.NewRequest("PATCH", url, strings.NewReader(string(jsonData)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// Get token from auth manager
|
||||
authToken, err := s.authManager.LoadToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load auth token: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken.AccessToken))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update investigation: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 && resp.StatusCode != 204 {
|
||||
return fmt.Errorf("supabase update error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
157
internal/system/system_info.go
Normal file
157
internal/system/system_info.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nannyagentv2/internal/executor"
|
||||
"nannyagentv2/internal/types"
|
||||
)
|
||||
|
||||
// SystemInfo is a snapshot of basic host facts. Every field is a
// human-readable string; GatherSystemInfo fills them from command output and
// from the local network interfaces.
type SystemInfo struct {
	// Hostname is the trimmed output of `hostname`.
	Hostname string `json:"hostname"`
	// OS is the distribution description.
	OS string `json:"os"`
	// Kernel is the trimmed output of `uname -r`.
	Kernel string `json:"kernel"`
	// Architecture is the trimmed output of `uname -m`.
	Architecture string `json:"architecture"`
	// CPUCores is the trimmed output of `nproc`.
	CPUCores string `json:"cpu_cores"`
	// Memory is the total memory column reported by `free -h`.
	Memory string `json:"memory"`
	// Uptime is the trimmed output of `uptime -p`.
	Uptime string `json:"uptime"`
	// PrivateIPs is a comma-separated list of "<ip> (<interface>)" entries.
	PrivateIPs string `json:"private_ips"`
	// LoadAverage is the load-average portion of `uptime`.
	LoadAverage string `json:"load_average"`
	// DiskUsage summarises usage of the root filesystem.
	DiskUsage string `json:"disk_usage"`
}
|
||||
|
||||
// GatherSystemInfo collects basic system information
|
||||
func GatherSystemInfo() *SystemInfo {
|
||||
info := &SystemInfo{}
|
||||
executor := executor.NewCommandExecutor(5 * time.Second)
|
||||
|
||||
// Basic system info
|
||||
if result := executor.Execute(types.Command{ID: "hostname", Command: "hostname"}); result.ExitCode == 0 {
|
||||
info.Hostname = strings.TrimSpace(result.Output)
|
||||
}
|
||||
|
||||
if result := executor.Execute(types.Command{ID: "os", Command: "lsb_release -d 2>/dev/null | cut -f2 || cat /etc/os-release | grep PRETTY_NAME | cut -d'=' -f2 | tr -d '\"'"}); result.ExitCode == 0 {
|
||||
info.OS = strings.TrimSpace(result.Output)
|
||||
}
|
||||
|
||||
if result := executor.Execute(types.Command{ID: "kernel", Command: "uname -r"}); result.ExitCode == 0 {
|
||||
info.Kernel = strings.TrimSpace(result.Output)
|
||||
}
|
||||
|
||||
if result := executor.Execute(types.Command{ID: "arch", Command: "uname -m"}); result.ExitCode == 0 {
|
||||
info.Architecture = strings.TrimSpace(result.Output)
|
||||
}
|
||||
|
||||
if result := executor.Execute(types.Command{ID: "cores", Command: "nproc"}); result.ExitCode == 0 {
|
||||
info.CPUCores = strings.TrimSpace(result.Output)
|
||||
}
|
||||
|
||||
if result := executor.Execute(types.Command{ID: "memory", Command: "free -h | grep Mem | awk '{print $2}'"}); result.ExitCode == 0 {
|
||||
info.Memory = strings.TrimSpace(result.Output)
|
||||
}
|
||||
|
||||
if result := executor.Execute(types.Command{ID: "uptime", Command: "uptime -p"}); result.ExitCode == 0 {
|
||||
info.Uptime = strings.TrimSpace(result.Output)
|
||||
}
|
||||
|
||||
if result := executor.Execute(types.Command{ID: "load", Command: "uptime | awk -F'load average:' '{print $2}' | xargs"}); result.ExitCode == 0 {
|
||||
info.LoadAverage = strings.TrimSpace(result.Output)
|
||||
}
|
||||
|
||||
if result := executor.Execute(types.Command{ID: "disk", Command: "df -h / | tail -1 | awk '{print \"Root: \" $3 \"/\" $2 \" (\" $5 \" used)\"}'"}); result.ExitCode == 0 {
|
||||
info.DiskUsage = strings.TrimSpace(result.Output)
|
||||
}
|
||||
|
||||
// Get private IP addresses
|
||||
info.PrivateIPs = getPrivateIPs()
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// getPrivateIPs returns private IP addresses
|
||||
func getPrivateIPs() string {
|
||||
var privateIPs []string
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return "Unable to determine"
|
||||
}
|
||||
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
|
||||
continue // Skip down or loopback interfaces
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
|
||||
if isPrivateIP(ipnet.IP) {
|
||||
privateIPs = append(privateIPs, fmt.Sprintf("%s (%s)", ipnet.IP.String(), iface.Name))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(privateIPs) == 0 {
|
||||
return "No private IPs found"
|
||||
}
|
||||
|
||||
return strings.Join(privateIPs, ", ")
|
||||
}
|
||||
|
||||
// isPrivateIP reports whether ip lies in one of the RFC 1918 private IPv4
// ranges: 10.0.0.0/8, 172.16.0.0/12 or 192.168.0.0/16.
//
// IPv4-mapped IPv6 addresses are judged by their embedded IPv4 address. Nil,
// malformed and plain IPv6 addresses yield false.
//
// The ranges are checked directly on the address octets, so no CIDR string is
// parsed per call and there is no parse error to ignore.
func isPrivateIP(ip net.IP) bool {
	v4 := ip.To4()
	if v4 == nil {
		return false
	}

	switch {
	case v4[0] == 10: // 10.0.0.0/8
		return true
	case v4[0] == 172 && v4[1]&0xf0 == 16: // 172.16.0.0/12
		return true
	case v4[0] == 192 && v4[1] == 168: // 192.168.0.0/16
		return true
	}
	return false
}
|
||||
|
||||
// FormatSystemInfoForPrompt formats system information for inclusion in diagnostic prompts
|
||||
func FormatSystemInfoForPrompt(info *SystemInfo) string {
|
||||
return fmt.Sprintf(`SYSTEM INFORMATION:
|
||||
- Hostname: %s
|
||||
- Operating System: %s
|
||||
- Kernel Version: %s
|
||||
- Architecture: %s
|
||||
- CPU Cores: %s
|
||||
- Total Memory: %s
|
||||
- System Uptime: %s
|
||||
- Current Load Average: %s
|
||||
- Root Disk Usage: %s
|
||||
- Private IP Addresses: %s
|
||||
- Go Runtime: %s
|
||||
|
||||
ISSUE DESCRIPTION:`,
|
||||
info.Hostname,
|
||||
info.OS,
|
||||
info.Kernel,
|
||||
info.Architecture,
|
||||
info.CPUCores,
|
||||
info.Memory,
|
||||
info.Uptime,
|
||||
info.LoadAverage,
|
||||
info.DiskUsage,
|
||||
info.PrivateIPs,
|
||||
runtime.Version())
|
||||
}
|
||||
@@ -1,6 +1,12 @@
|
||||
package types
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"nannyagentv2/internal/ebpf"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// SystemMetrics represents comprehensive system performance metrics
|
||||
type SystemMetrics struct {
|
||||
@@ -59,43 +65,47 @@ type SystemMetrics struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// FilesystemInfo represents individual filesystem statistics
|
||||
// FilesystemInfo represents filesystem information
|
||||
type FilesystemInfo struct {
|
||||
Device string `json:"device"`
|
||||
Mountpoint string `json:"mountpoint"`
|
||||
Type string `json:"type"`
|
||||
Fstype string `json:"fstype"`
|
||||
Total uint64 `json:"total"`
|
||||
Used uint64 `json:"used"`
|
||||
Free uint64 `json:"free"`
|
||||
Usage float64 `json:"usage"`
|
||||
UsagePercent float64 `json:"usage_percent"`
|
||||
}
|
||||
|
||||
// BlockDevice represents block device information
|
||||
// BlockDevice represents a block device
|
||||
type BlockDevice struct {
|
||||
Name string `json:"name"`
|
||||
Size uint64 `json:"size"`
|
||||
Model string `json:"model"`
|
||||
Type string `json:"type"`
|
||||
Model string `json:"model,omitempty"`
|
||||
SerialNumber string `json:"serial_number"`
|
||||
}
|
||||
|
||||
// NetworkStats represents detailed network interface statistics
|
||||
// NetworkStats represents network interface statistics
|
||||
type NetworkStats struct {
|
||||
InterfaceName string `json:"interface_name"`
|
||||
BytesSent uint64 `json:"bytes_sent"`
|
||||
BytesRecv uint64 `json:"bytes_recv"`
|
||||
PacketsSent uint64 `json:"packets_sent"`
|
||||
PacketsRecv uint64 `json:"packets_recv"`
|
||||
ErrorsIn uint64 `json:"errors_in"`
|
||||
ErrorsOut uint64 `json:"errors_out"`
|
||||
DropsIn uint64 `json:"drops_in"`
|
||||
DropsOut uint64 `json:"drops_out"`
|
||||
Interface string `json:"interface"`
|
||||
BytesRecv uint64 `json:"bytes_recv"`
|
||||
BytesSent uint64 `json:"bytes_sent"`
|
||||
PacketsRecv uint64 `json:"packets_recv"`
|
||||
PacketsSent uint64 `json:"packets_sent"`
|
||||
ErrorsIn uint64 `json:"errors_in"`
|
||||
ErrorsOut uint64 `json:"errors_out"`
|
||||
DropsIn uint64 `json:"drops_in"`
|
||||
DropsOut uint64 `json:"drops_out"`
|
||||
}
|
||||
|
||||
// AuthToken represents the authentication token structure
|
||||
// AuthToken represents an authentication token
|
||||
type AuthToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
AgentID string `json:"agent_id"`
|
||||
}
|
||||
|
||||
@@ -169,53 +179,14 @@ type MetricsRequest struct {
|
||||
NetworkStats map[string]uint64 `json:"network_stats"`
|
||||
}
|
||||
|
||||
// eBPF related types
|
||||
type EBPFEvent struct {
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
EventType string `json:"event_type"`
|
||||
ProcessID int `json:"process_id"`
|
||||
ProcessName string `json:"process_name"`
|
||||
UserID int `json:"user_id"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
type EBPFTrace struct {
|
||||
TraceID string `json:"trace_id"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Capability string `json:"capability"`
|
||||
Events []EBPFEvent `json:"events"`
|
||||
Summary string `json:"summary"`
|
||||
EventCount int `json:"event_count"`
|
||||
ProcessList []string `json:"process_list"`
|
||||
}
|
||||
|
||||
type EBPFRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // "tracepoint", "kprobe", "kretprobe"
|
||||
Target string `json:"target"` // tracepoint path or function name
|
||||
Duration int `json:"duration"` // seconds
|
||||
Filters map[string]string `json:"filters,omitempty"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
type NetworkEvent struct {
|
||||
Timestamp uint64 `json:"timestamp"`
|
||||
PID uint32 `json:"pid"`
|
||||
TID uint32 `json:"tid"`
|
||||
UID uint32 `json:"uid"`
|
||||
EventType string `json:"event_type"`
|
||||
Comm [16]byte `json:"-"`
|
||||
CommStr string `json:"comm"`
|
||||
}
|
||||
|
||||
// Agent types
|
||||
// Agent types for TensorZero integration
|
||||
type DiagnosticResponse struct {
|
||||
ResponseType string `json:"response_type"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
Commands []Command `json:"commands"`
|
||||
}
|
||||
|
||||
// ResolutionResponse represents a resolution response
|
||||
type ResolutionResponse struct {
|
||||
ResponseType string `json:"response_type"`
|
||||
RootCause string `json:"root_cause"`
|
||||
@@ -223,12 +194,14 @@ type ResolutionResponse struct {
|
||||
Confidence string `json:"confidence"`
|
||||
}
|
||||
|
||||
// Command represents a command to execute
|
||||
type Command struct {
|
||||
ID string `json:"id"`
|
||||
Command string `json:"command"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// CommandResult represents the result of an executed command
|
||||
type CommandResult struct {
|
||||
ID string `json:"id"`
|
||||
Command string `json:"command"`
|
||||
@@ -238,6 +211,17 @@ type CommandResult struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// EBPFRequest describes one eBPF trace requested by an external API.
type EBPFRequest struct {
	// Name identifies the request.
	Name string `json:"name"`
	// Type is the probe kind: "tracepoint", "kprobe" or "kretprobe".
	Type string `json:"type"`
	// Target is the tracepoint path or function name to attach to.
	Target string `json:"target"`
	// Duration is the trace length in seconds.
	Duration int `json:"duration"`
	// Filters holds optional filter key/value pairs; omitted from JSON when empty.
	Filters map[string]string `json:"filters,omitempty"`
	// Description is free-form text about the request.
	Description string `json:"description"`
}
|
||||
|
||||
// EBPFEnhancedDiagnosticResponse represents enhanced diagnostic response with eBPF
|
||||
type EBPFEnhancedDiagnosticResponse struct {
|
||||
ResponseType string `json:"response_type"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
@@ -246,79 +230,20 @@ type EBPFEnhancedDiagnosticResponse struct {
|
||||
NextActions []string `json:"next_actions,omitempty"`
|
||||
}
|
||||
|
||||
// TensorZeroRequest represents a request to TensorZero
|
||||
type TensorZeroRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []map[string]interface{} `json:"messages"`
|
||||
EpisodeID string `json:"tensorzero::episode_id,omitempty"`
|
||||
}
|
||||
|
||||
// TensorZeroResponse represents a response from TensorZero
|
||||
type TensorZeroResponse struct {
|
||||
Choices []map[string]interface{} `json:"choices"`
|
||||
EpisodeID string `json:"episode_id"`
|
||||
}
|
||||
|
||||
// WebSocket types
|
||||
type WebSocketMessage struct {
|
||||
Type string `json:"type"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
type InvestigationTask struct {
|
||||
TaskID string `json:"task_id"`
|
||||
InvestigationID string `json:"investigation_id"`
|
||||
AgentID string `json:"agent_id"`
|
||||
DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
|
||||
EpisodeID string `json:"episode_id,omitempty"`
|
||||
}
|
||||
|
||||
type TaskResult struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Success bool `json:"success"`
|
||||
CommandResults map[string]interface{} `json:"command_results,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type HeartbeatData struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// Investigation server types
|
||||
type InvestigationRequest struct {
|
||||
Issue string `json:"issue"`
|
||||
AgentID string `json:"agent_id"`
|
||||
EpisodeID string `json:"episode_id,omitempty"`
|
||||
Timestamp string `json:"timestamp,omitempty"`
|
||||
Priority string `json:"priority,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
type InvestigationResponse struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
Results map[string]interface{} `json:"results,omitempty"`
|
||||
AgentID string `json:"agent_id"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
EpisodeID string `json:"episode_id,omitempty"`
|
||||
Investigation *PendingInvestigation `json:"investigation,omitempty"`
|
||||
}
|
||||
|
||||
type PendingInvestigation struct {
|
||||
ID string `json:"id"`
|
||||
Issue string `json:"issue"`
|
||||
AgentID string `json:"agent_id"`
|
||||
Status string `json:"status"`
|
||||
DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
|
||||
CommandResults map[string]interface{} `json:"command_results,omitempty"`
|
||||
EpisodeID *string `json:"episode_id,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
ErrorMessage *string `json:"error_message,omitempty"`
|
||||
}
|
||||
|
||||
// System types
|
||||
// SystemInfo represents system information (for compatibility)
|
||||
type SystemInfo struct {
|
||||
Hostname string `json:"hostname"`
|
||||
Platform string `json:"platform"`
|
||||
@@ -331,7 +256,35 @@ type SystemInfo struct {
|
||||
DiskInfo []map[string]string `json:"disk_info"`
|
||||
}
|
||||
|
||||
// Executor types
|
||||
type CommandExecutor struct {
|
||||
timeout time.Duration
|
||||
// AgentConfig holds the agent's configuration as it is (de)serialised to JSON.
type AgentConfig struct {
	// TensorZeroAPIKey is stored under "tensorzero_api_key".
	TensorZeroAPIKey string `json:"tensorzero_api_key"`
	// APIURL is stored under "api_url".
	APIURL string `json:"api_url"`
	// Timeout is stored under "timeout".
	// NOTE(review): unit is not visible here — confirm.
	Timeout int `json:"timeout"`
	// Debug is stored under "debug".
	Debug bool `json:"debug"`
	// MaxRetries is stored under "max_retries".
	MaxRetries int `json:"max_retries"`
	// BackoffFactor is stored under "backoff_factor".
	BackoffFactor int `json:"backoff_factor"`
	// EpisodeID is optional and omitted from JSON when empty.
	EpisodeID string `json:"episode_id,omitempty"`
}
|
||||
|
||||
// PendingInvestigation mirrors one pending-investigation record from the
// database.
type PendingInvestigation struct {
	// ID is the record's own identifier.
	ID string `json:"id"`
	// InvestigationID identifies the investigation the record belongs to.
	InvestigationID string `json:"investigation_id"`
	// AgentID is stored under "agent_id".
	AgentID string `json:"agent_id"`
	// DiagnosticPayload is the free-form payload stored with the record.
	DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
	// EpisodeID is nil when the JSON value is null or absent.
	EpisodeID *string `json:"episode_id"`
	// Status is stored under "status".
	Status string `json:"status"`
	// CreatedAt is decoded from an RFC 3339 timestamp.
	CreatedAt time.Time `json:"created_at"`
}
|
||||
|
||||
// DiagnosticAgent interface for agent functionality needed by other packages
|
||||
type DiagnosticAgent interface {
|
||||
DiagnoseIssue(issue string) error
|
||||
// Exported method names to match what websocket client calls
|
||||
ConvertEBPFProgramsToTraceSpecs(ebpfRequests []EBPFRequest) []ebpf.TraceSpec
|
||||
ExecuteEBPFTraces(traceSpecs []ebpf.TraceSpec) []map[string]interface{}
|
||||
SendRequestWithEpisode(messages []openai.ChatCompletionMessage, episodeID string) (*openai.ChatCompletionResponse, error)
|
||||
SendRequest(messages []openai.ChatCompletionMessage) (*openai.ChatCompletionResponse, error)
|
||||
ExecuteCommand(cmd Command) CommandResult
|
||||
}
|
||||
|
||||
842
internal/websocket/websocket_client.go
Normal file
842
internal/websocket/websocket_client.go
Normal file
@@ -0,0 +1,842 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nannyagentv2/internal/auth"
|
||||
"nannyagentv2/internal/logging"
|
||||
"nannyagentv2/internal/metrics"
|
||||
"nannyagentv2/internal/types"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// Helper function for minimum of two integers
|
||||
|
||||
// WebSocketMessage is the envelope for every frame exchanged over the
// WebSocket: a type discriminator plus an arbitrary payload.
type WebSocketMessage struct {
	// Type selects how Data is interpreted (e.g. "investigation_task").
	Type string `json:"type"`
	// Data is the message payload; its shape depends on Type.
	Data interface{} `json:"data"`
}
|
||||
|
||||
// InvestigationTask is the payload of an "investigation_task" message sent to
// the agent.
type InvestigationTask struct {
	// TaskID is echoed back in the matching TaskResult.
	TaskID string `json:"task_id"`
	// InvestigationID is stored under "investigation_id".
	InvestigationID string `json:"investigation_id"`
	// AgentID is stored under "agent_id".
	AgentID string `json:"agent_id"`
	// DiagnosticPayload carries the commands (and optional eBPF programs) to run.
	DiagnosticPayload map[string]interface{} `json:"diagnostic_payload"`
	// EpisodeID is optional and omitted from JSON when empty.
	EpisodeID string `json:"episode_id,omitempty"`
}
|
||||
|
||||
// TaskResult reports the outcome of one investigation task back to the
// backend.
type TaskResult struct {
	// TaskID is the ID of the task this result answers.
	TaskID string `json:"task_id"`
	// Success is true when the task ran without error.
	Success bool `json:"success"`
	// CommandResults holds the execution results; omitted from JSON when empty.
	CommandResults map[string]interface{} `json:"command_results,omitempty"`
	// Error is the failure text; omitted from JSON when empty.
	Error string `json:"error,omitempty"`
}
|
||||
|
||||
// HeartbeatData is the payload carried by a heartbeat message.
type HeartbeatData struct {
	// AgentID is stored under "agent_id".
	AgentID string `json:"agent_id"`
	// Timestamp is stored under "timestamp".
	Timestamp time.Time `json:"timestamp"`
	// Version is stored under "version".
	Version string `json:"version"`
}
|
||||
|
||||
// WebSocketClient handles WebSocket connection to Supabase backend
|
||||
type WebSocketClient struct {
|
||||
agent types.DiagnosticAgent // DiagnosticAgent interface
|
||||
conn *websocket.Conn
|
||||
agentID string
|
||||
authManager *auth.AuthManager
|
||||
metricsCollector *metrics.Collector
|
||||
supabaseURL string
|
||||
token string
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
consecutiveFailures int // Track consecutive connection failures
|
||||
}
|
||||
|
||||
// NewWebSocketClient creates a new WebSocket client
|
||||
func NewWebSocketClient(agent types.DiagnosticAgent, authManager *auth.AuthManager) *WebSocketClient {
|
||||
// Get agent ID from authentication system
|
||||
var agentID string
|
||||
if authManager != nil {
|
||||
if id, err := authManager.GetCurrentAgentID(); err == nil {
|
||||
agentID = id
|
||||
// Agent ID retrieved successfully
|
||||
} else {
|
||||
logging.Error("Failed to get agent ID from auth manager: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to environment variable or generate one if auth fails
|
||||
if agentID == "" {
|
||||
agentID = os.Getenv("AGENT_ID")
|
||||
if agentID == "" {
|
||||
agentID = fmt.Sprintf("agent-%d", time.Now().Unix())
|
||||
}
|
||||
}
|
||||
|
||||
supabaseURL := os.Getenv("SUPABASE_PROJECT_URL")
|
||||
if supabaseURL == "" {
|
||||
log.Fatal("❌ SUPABASE_PROJECT_URL environment variable is required")
|
||||
}
|
||||
|
||||
// Create metrics collector
|
||||
metricsCollector := metrics.NewCollector("v2.0.0")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &WebSocketClient{
|
||||
agent: agent,
|
||||
agentID: agentID,
|
||||
authManager: authManager,
|
||||
metricsCollector: metricsCollector,
|
||||
supabaseURL: supabaseURL,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the WebSocket connection and message handling
|
||||
func (w *WebSocketClient) Start() error {
|
||||
// Starting WebSocket client
|
||||
|
||||
if err := w.connect(); err != nil {
|
||||
return fmt.Errorf("failed to establish WebSocket connection: %v", err)
|
||||
}
|
||||
|
||||
// Start message reading loop
|
||||
go w.handleMessages()
|
||||
|
||||
// Start heartbeat
|
||||
go w.startHeartbeat()
|
||||
|
||||
// Start database polling for pending investigations
|
||||
go w.pollPendingInvestigations()
|
||||
|
||||
// WebSocket client started
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the WebSocket connection
|
||||
func (c *WebSocketClient) Stop() {
|
||||
c.cancel()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// getAuthToken retrieves authentication token
|
||||
func (c *WebSocketClient) getAuthToken() error {
|
||||
if c.authManager == nil {
|
||||
return fmt.Errorf("auth manager not available")
|
||||
}
|
||||
|
||||
token, err := c.authManager.EnsureAuthenticated()
|
||||
if err != nil {
|
||||
return fmt.Errorf("authentication failed: %v", err)
|
||||
}
|
||||
|
||||
c.token = token.AccessToken
|
||||
return nil
|
||||
}
|
||||
|
||||
// connect establishes WebSocket connection
|
||||
func (c *WebSocketClient) connect() error {
|
||||
// Get fresh auth token
|
||||
if err := c.getAuthToken(); err != nil {
|
||||
return fmt.Errorf("failed to get auth token: %v", err)
|
||||
}
|
||||
|
||||
// Convert HTTP URL to WebSocket URL
|
||||
wsURL := strings.Replace(c.supabaseURL, "https://", "wss://", 1)
|
||||
wsURL = strings.Replace(wsURL, "http://", "ws://", 1)
|
||||
wsURL += "/functions/v1/websocket-agent-handler"
|
||||
|
||||
// Connecting to WebSocket
|
||||
|
||||
// Set up headers
|
||||
headers := http.Header{}
|
||||
headers.Set("Authorization", "Bearer "+c.token)
|
||||
|
||||
// Connect
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
conn, resp, err := dialer.Dial(wsURL, headers)
|
||||
if err != nil {
|
||||
c.consecutiveFailures++
|
||||
if c.consecutiveFailures >= 5 && resp != nil {
|
||||
logging.Error("WebSocket handshake failed with status: %d (failure #%d)", resp.StatusCode, c.consecutiveFailures)
|
||||
}
|
||||
return fmt.Errorf("websocket connection failed: %v", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
// WebSocket client connected
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMessages processes incoming WebSocket messages
|
||||
func (c *WebSocketClient) handleMessages() {
|
||||
defer func() {
|
||||
if c.conn != nil {
|
||||
// Closing WebSocket connection
|
||||
c.conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Started WebSocket message listener
|
||||
connectionStart := time.Now()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
// Only log context cancellation if there have been failures
|
||||
if c.consecutiveFailures >= 5 {
|
||||
logging.Debug("Context cancelled after %v, stopping message handler", time.Since(connectionStart))
|
||||
}
|
||||
return
|
||||
default:
|
||||
// Set read deadline to detect connection issues
|
||||
c.conn.SetReadDeadline(time.Now().Add(90 * time.Second))
|
||||
|
||||
var message WebSocketMessage
|
||||
readStart := time.Now()
|
||||
err := c.conn.ReadJSON(&message)
|
||||
readDuration := time.Since(readStart)
|
||||
|
||||
if err != nil {
|
||||
connectionDuration := time.Since(connectionStart)
|
||||
|
||||
// Only log specific errors after failure threshold
|
||||
if c.consecutiveFailures >= 5 {
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
logging.Debug("WebSocket closed normally after %v: %v", connectionDuration, err)
|
||||
} else if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
logging.Error("ABNORMAL CLOSE after %v (code 1006 = server-side timeout/kill): %v", connectionDuration, err)
|
||||
logging.Debug("Last read took %v, connection lived %v", readDuration, connectionDuration)
|
||||
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
logging.Warning("READ TIMEOUT after %v: %v", connectionDuration, err)
|
||||
} else {
|
||||
logging.Error("WebSocket error after %v: %v", connectionDuration, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Track consecutive failures for diagnostic threshold
|
||||
c.consecutiveFailures++
|
||||
|
||||
// Only show diagnostics after multiple failures
|
||||
if c.consecutiveFailures >= 5 {
|
||||
logging.Debug("DIAGNOSTIC - Connection failed #%d after %v", c.consecutiveFailures, connectionDuration)
|
||||
}
|
||||
|
||||
// Attempt reconnection instead of returning immediately
|
||||
go c.attemptReconnection()
|
||||
return
|
||||
}
|
||||
|
||||
// Received WebSocket message successfully - reset failure counter
|
||||
c.consecutiveFailures = 0
|
||||
|
||||
switch message.Type {
|
||||
case "connection_ack":
|
||||
// Connection acknowledged
|
||||
|
||||
case "heartbeat_ack":
|
||||
// Heartbeat acknowledged
|
||||
|
||||
case "investigation_task":
|
||||
// Received investigation task - processing
|
||||
go c.handleInvestigationTask(message.Data)
|
||||
|
||||
case "task_result_ack":
|
||||
// Task result acknowledged
|
||||
|
||||
default:
|
||||
logging.Warning("Unknown message type: %s", message.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleInvestigationTask processes investigation tasks from the backend
|
||||
func (c *WebSocketClient) handleInvestigationTask(data interface{}) {
|
||||
// Parse task data
|
||||
taskBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
logging.Error("Error marshaling task data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var task InvestigationTask
|
||||
err = json.Unmarshal(taskBytes, &task)
|
||||
if err != nil {
|
||||
logging.Error("Error unmarshaling investigation task: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Processing investigation task
|
||||
|
||||
// Execute diagnostic commands
|
||||
results, err := c.executeDiagnosticCommands(task.DiagnosticPayload)
|
||||
|
||||
// Prepare task result
|
||||
taskResult := TaskResult{
|
||||
TaskID: task.TaskID,
|
||||
Success: err == nil,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
taskResult.Error = err.Error()
|
||||
logging.Error("Task execution failed: %v", err)
|
||||
} else {
|
||||
taskResult.CommandResults = results
|
||||
// Task executed successfully
|
||||
}
|
||||
|
||||
// Send result back
|
||||
c.sendTaskResult(taskResult)
|
||||
}
|
||||
|
||||
// executeDiagnosticCommands executes the commands from a diagnostic response
|
||||
func (c *WebSocketClient) executeDiagnosticCommands(diagnosticPayload map[string]interface{}) (map[string]interface{}, error) {
|
||||
results := map[string]interface{}{
|
||||
"agent_id": c.agentID,
|
||||
"execution_time": time.Now().UTC().Format(time.RFC3339),
|
||||
"command_results": []map[string]interface{}{},
|
||||
}
|
||||
|
||||
// Extract commands from diagnostic payload
|
||||
commands, ok := diagnosticPayload["commands"].([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no commands found in diagnostic payload")
|
||||
}
|
||||
|
||||
var commandResults []map[string]interface{}
|
||||
|
||||
for _, cmd := range commands {
|
||||
cmdMap, ok := cmd.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
id, _ := cmdMap["id"].(string)
|
||||
command, _ := cmdMap["command"].(string)
|
||||
description, _ := cmdMap["description"].(string)
|
||||
|
||||
if command == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Executing command
|
||||
|
||||
// Execute the command
|
||||
output, exitCode, err := c.executeCommand(command)
|
||||
|
||||
result := map[string]interface{}{
|
||||
"id": id,
|
||||
"command": command,
|
||||
"description": description,
|
||||
"output": output,
|
||||
"exit_code": exitCode,
|
||||
"success": err == nil && exitCode == 0,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
result["error"] = err.Error()
|
||||
logging.Warning("Command [%s] failed: %v (exit code: %d)", id, err, exitCode)
|
||||
}
|
||||
|
||||
commandResults = append(commandResults, result)
|
||||
}
|
||||
|
||||
results["command_results"] = commandResults
|
||||
results["total_commands"] = len(commandResults)
|
||||
results["successful_commands"] = c.countSuccessfulCommands(commandResults)
|
||||
|
||||
// Execute eBPF programs if present
|
||||
ebpfPrograms, hasEBPF := diagnosticPayload["ebpf_programs"].([]interface{})
|
||||
if hasEBPF && len(ebpfPrograms) > 0 {
|
||||
ebpfResults := c.executeEBPFPrograms(ebpfPrograms)
|
||||
results["ebpf_results"] = ebpfResults
|
||||
results["total_ebpf_programs"] = len(ebpfPrograms)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// executeEBPFPrograms executes eBPF monitoring programs using the real eBPF manager
|
||||
func (c *WebSocketClient) executeEBPFPrograms(ebpfPrograms []interface{}) []map[string]interface{} {
|
||||
var ebpfRequests []types.EBPFRequest
|
||||
|
||||
// Convert interface{} to EBPFRequest structs
|
||||
for _, prog := range ebpfPrograms {
|
||||
progMap, ok := prog.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
name, _ := progMap["name"].(string)
|
||||
progType, _ := progMap["type"].(string)
|
||||
target, _ := progMap["target"].(string)
|
||||
duration, _ := progMap["duration"].(float64)
|
||||
description, _ := progMap["description"].(string)
|
||||
|
||||
if name == "" || progType == "" || target == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
ebpfRequests = append(ebpfRequests, types.EBPFRequest{
|
||||
Name: name,
|
||||
Type: progType,
|
||||
Target: target,
|
||||
Duration: int(duration),
|
||||
Description: description,
|
||||
})
|
||||
}
|
||||
|
||||
// Execute eBPF programs using the agent's new BCC concurrent execution logic
|
||||
traceSpecs := c.agent.ConvertEBPFProgramsToTraceSpecs(ebpfRequests)
|
||||
return c.agent.ExecuteEBPFTraces(traceSpecs)
|
||||
}
|
||||
|
||||
// executeCommandsFromPayload executes commands from a payload and returns results
|
||||
func (c *WebSocketClient) executeCommandsFromPayload(commands []interface{}) []map[string]interface{} {
|
||||
var commandResults []map[string]interface{}
|
||||
|
||||
for _, cmd := range commands {
|
||||
cmdMap, ok := cmd.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
id, _ := cmdMap["id"].(string)
|
||||
command, _ := cmdMap["command"].(string)
|
||||
description, _ := cmdMap["description"].(string)
|
||||
|
||||
if command == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Execute the command
|
||||
output, exitCode, err := c.executeCommand(command)
|
||||
|
||||
result := map[string]interface{}{
|
||||
"id": id,
|
||||
"command": command,
|
||||
"description": description,
|
||||
"output": output,
|
||||
"exit_code": exitCode,
|
||||
"success": err == nil && exitCode == 0,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
result["error"] = err.Error()
|
||||
logging.Warning("Command [%s] failed: %v (exit code: %d)", id, err, exitCode)
|
||||
}
|
||||
|
||||
commandResults = append(commandResults, result)
|
||||
}
|
||||
|
||||
return commandResults
|
||||
}
|
||||
|
||||
// executeCommand executes a shell command and returns output, exit code, and error
|
||||
func (c *WebSocketClient) executeCommand(command string) (string, int, error) {
|
||||
// Parse command into parts
|
||||
parts := strings.Fields(command)
|
||||
if len(parts) == 0 {
|
||||
return "", -1, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
// Create command with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, parts[0], parts[1:]...)
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
exitCode := 0
|
||||
|
||||
if err != nil {
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitError.ExitCode()
|
||||
} else {
|
||||
exitCode = -1
|
||||
}
|
||||
}
|
||||
|
||||
return string(output), exitCode, err
|
||||
}
|
||||
|
||||
// countSuccessfulCommands counts the number of successful commands
|
||||
func (c *WebSocketClient) countSuccessfulCommands(results []map[string]interface{}) int {
|
||||
count := 0
|
||||
for _, result := range results {
|
||||
if success, ok := result["success"].(bool); ok && success {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// sendTaskResult sends a task result back to the backend
|
||||
func (c *WebSocketClient) sendTaskResult(result TaskResult) {
|
||||
message := WebSocketMessage{
|
||||
Type: "task_result",
|
||||
Data: result,
|
||||
}
|
||||
|
||||
err := c.conn.WriteJSON(message)
|
||||
if err != nil {
|
||||
logging.Error("Error sending task result: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// startHeartbeat sends periodic heartbeat messages
|
||||
func (c *WebSocketClient) startHeartbeat() {
|
||||
ticker := time.NewTicker(30 * time.Second) // Heartbeat every 30 seconds
|
||||
defer ticker.Stop()
|
||||
|
||||
// Starting heartbeat
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
logging.Debug("Heartbeat stopped due to context cancellation")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Sending heartbeat
|
||||
heartbeat := WebSocketMessage{
|
||||
Type: "heartbeat",
|
||||
Data: HeartbeatData{
|
||||
AgentID: c.agentID,
|
||||
Timestamp: time.Now(),
|
||||
Version: "v2.0.0",
|
||||
},
|
||||
}
|
||||
|
||||
err := c.conn.WriteJSON(heartbeat)
|
||||
if err != nil {
|
||||
logging.Error("Error sending heartbeat: %v", err)
|
||||
logging.Debug("Heartbeat failed, connection likely dead")
|
||||
return
|
||||
}
|
||||
// Heartbeat sent
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pollPendingInvestigations polls the database for pending investigations
|
||||
func (c *WebSocketClient) pollPendingInvestigations() {
|
||||
// Starting database polling
|
||||
ticker := time.NewTicker(5 * time.Second) // Poll every 5 seconds
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.checkForPendingInvestigations()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkForPendingInvestigations checks the database for new pending investigations via proxy
|
||||
func (c *WebSocketClient) checkForPendingInvestigations() {
|
||||
// Use Edge Function proxy instead of direct database access
|
||||
url := fmt.Sprintf("%s/functions/v1/agent-database-proxy/pending-investigations", c.supabaseURL)
|
||||
|
||||
// Poll database for pending investigations
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
// Request creation failed
|
||||
return
|
||||
}
|
||||
|
||||
// Only JWT token needed for proxy - no API keys exposed
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// Database request failed
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return
|
||||
}
|
||||
|
||||
var investigations []types.PendingInvestigation
|
||||
err = json.NewDecoder(resp.Body).Decode(&investigations)
|
||||
if err != nil {
|
||||
// Response decode failed
|
||||
return
|
||||
}
|
||||
|
||||
for _, investigation := range investigations {
|
||||
go c.handlePendingInvestigation(investigation)
|
||||
}
|
||||
}
|
||||
|
||||
// handlePendingInvestigation processes a pending investigation from database polling
|
||||
func (c *WebSocketClient) handlePendingInvestigation(investigation types.PendingInvestigation) {
|
||||
// Processing pending investigation
|
||||
|
||||
// Mark as executing
|
||||
err := c.updateInvestigationStatus(investigation.ID, "executing", nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Execute diagnostic commands
|
||||
results, err := c.executeDiagnosticCommands(investigation.DiagnosticPayload)
|
||||
|
||||
// Prepare the base results map we'll send to DB
|
||||
resultsForDB := map[string]interface{}{
|
||||
"agent_id": c.agentID,
|
||||
"execution_time": time.Now().UTC().Format(time.RFC3339),
|
||||
"command_results": results,
|
||||
}
|
||||
|
||||
// If command execution failed, mark investigation as failed
|
||||
if err != nil {
|
||||
errorMsg := err.Error()
|
||||
// Include partial results when possible
|
||||
if results != nil {
|
||||
resultsForDB["command_results"] = results
|
||||
}
|
||||
c.updateInvestigationStatus(investigation.ID, "failed", resultsForDB, &errorMsg)
|
||||
// Investigation failed
|
||||
return
|
||||
}
|
||||
|
||||
// Try to continue the TensorZero conversation by sending command results back
|
||||
// Build messages: assistant = diagnostic payload, user = command results
|
||||
diagJSON, _ := json.Marshal(investigation.DiagnosticPayload)
|
||||
commandsJSON, _ := json.MarshalIndent(results, "", " ")
|
||||
|
||||
messages := []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: string(diagJSON),
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: string(commandsJSON),
|
||||
},
|
||||
}
|
||||
|
||||
// Use the episode ID from the investigation to maintain conversation continuity
|
||||
episodeID := ""
|
||||
if investigation.EpisodeID != nil {
|
||||
episodeID = *investigation.EpisodeID
|
||||
}
|
||||
|
||||
// Continue conversation until resolution (same as agent)
|
||||
var finalAIContent string
|
||||
for {
|
||||
tzResp, tzErr := c.agent.SendRequestWithEpisode(messages, episodeID)
|
||||
if tzErr != nil {
|
||||
logging.Warning("TensorZero continuation failed: %v", tzErr)
|
||||
// Fall back to marking completed with command results only
|
||||
c.updateInvestigationStatus(investigation.ID, "completed", resultsForDB, nil)
|
||||
return
|
||||
}
|
||||
|
||||
if len(tzResp.Choices) == 0 {
|
||||
logging.Warning("No choices in TensorZero response")
|
||||
c.updateInvestigationStatus(investigation.ID, "completed", resultsForDB, nil)
|
||||
return
|
||||
}
|
||||
|
||||
aiContent := tzResp.Choices[0].Message.Content
|
||||
if len(aiContent) > 300 {
|
||||
// AI response received successfully
|
||||
} else {
|
||||
logging.Debug("AI Response: %s", aiContent)
|
||||
}
|
||||
|
||||
// Check if this is a resolution response (final)
|
||||
var resolutionResp struct {
|
||||
ResponseType string `json:"response_type"`
|
||||
RootCause string `json:"root_cause"`
|
||||
ResolutionPlan string `json:"resolution_plan"`
|
||||
Confidence string `json:"confidence"`
|
||||
}
|
||||
|
||||
logging.Debug("Analyzing AI response type...")
|
||||
|
||||
if err := json.Unmarshal([]byte(aiContent), &resolutionResp); err == nil && resolutionResp.ResponseType == "resolution" {
|
||||
// This is the final resolution - show summary and complete
|
||||
logging.Info("=== DIAGNOSIS COMPLETE ===")
|
||||
logging.Info("Root Cause: %s", resolutionResp.RootCause)
|
||||
logging.Info("Resolution Plan: %s", resolutionResp.ResolutionPlan)
|
||||
logging.Info("Confidence: %s", resolutionResp.Confidence)
|
||||
finalAIContent = aiContent
|
||||
break
|
||||
}
|
||||
|
||||
// Check if this is another diagnostic response requiring more commands
|
||||
var diagnosticResp struct {
|
||||
ResponseType string `json:"response_type"`
|
||||
Commands []interface{} `json:"commands"`
|
||||
EBPFPrograms []interface{} `json:"ebpf_programs"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(aiContent), &diagnosticResp); err == nil && diagnosticResp.ResponseType == "diagnostic" {
|
||||
logging.Debug("AI requested additional diagnostics, executing...")
|
||||
|
||||
// Execute additional commands if any
|
||||
additionalResults := map[string]interface{}{
|
||||
"command_results": []map[string]interface{}{},
|
||||
}
|
||||
|
||||
if len(diagnosticResp.Commands) > 0 {
|
||||
logging.Debug("Executing %d additional diagnostic commands", len(diagnosticResp.Commands))
|
||||
commandResults := c.executeCommandsFromPayload(diagnosticResp.Commands)
|
||||
additionalResults["command_results"] = commandResults
|
||||
}
|
||||
|
||||
// Execute additional eBPF programs if any
|
||||
if len(diagnosticResp.EBPFPrograms) > 0 {
|
||||
ebpfResults := c.executeEBPFPrograms(diagnosticResp.EBPFPrograms)
|
||||
additionalResults["ebpf_results"] = ebpfResults
|
||||
}
|
||||
|
||||
// Add AI response and additional results to conversation
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: aiContent,
|
||||
})
|
||||
|
||||
additionalResultsJSON, _ := json.MarshalIndent(additionalResults, "", " ")
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: string(additionalResultsJSON),
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// If neither resolution nor diagnostic, treat as final response
|
||||
logging.Warning("Unknown response type - treating as final response")
|
||||
finalAIContent = aiContent
|
||||
break
|
||||
}
|
||||
|
||||
// Attach final AI response to results for DB and mark as completed_with_analysis
|
||||
resultsForDB["ai_response"] = finalAIContent
|
||||
c.updateInvestigationStatus(investigation.ID, "completed_with_analysis", resultsForDB, nil)
|
||||
}
|
||||
|
||||
// updateInvestigationStatus updates the status of a pending investigation
|
||||
func (c *WebSocketClient) updateInvestigationStatus(id, status string, results map[string]interface{}, errorMsg *string) error {
|
||||
updateData := map[string]interface{}{
|
||||
"status": status,
|
||||
}
|
||||
|
||||
if status == "executing" {
|
||||
updateData["started_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
} else if status == "completed" {
|
||||
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
if results != nil {
|
||||
updateData["command_results"] = results
|
||||
}
|
||||
} else if status == "failed" && errorMsg != nil {
|
||||
updateData["error_message"] = *errorMsg
|
||||
updateData["completed_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(updateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal update data: %v", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/functions/v1/agent-database-proxy/pending-investigations/%s", c.supabaseURL, id)
|
||||
req, err := http.NewRequest("PATCH", url, strings.NewReader(string(jsonData)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// Only JWT token needed for proxy - no API keys exposed
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update investigation: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 && resp.StatusCode != 204 {
|
||||
return fmt.Errorf("supabase update error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// attemptReconnection attempts to reconnect the WebSocket with backoff
|
||||
func (c *WebSocketClient) attemptReconnection() {
|
||||
backoffDurations := []time.Duration{
|
||||
2 * time.Second,
|
||||
5 * time.Second,
|
||||
10 * time.Second,
|
||||
20 * time.Second,
|
||||
30 * time.Second,
|
||||
}
|
||||
|
||||
for i, backoff := range backoffDurations {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
c.consecutiveFailures++
|
||||
|
||||
// Only show messages after 5 consecutive failures
|
||||
if c.consecutiveFailures >= 5 {
|
||||
logging.Info("Attempting WebSocket reconnection (attempt %d/%d) - %d consecutive failures", i+1, len(backoffDurations), c.consecutiveFailures)
|
||||
}
|
||||
|
||||
time.Sleep(backoff)
|
||||
|
||||
if err := c.connect(); err != nil {
|
||||
if c.consecutiveFailures >= 5 {
|
||||
logging.Warning("Reconnection attempt %d failed: %v", i+1, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Successfully reconnected - reset failure counter
|
||||
if c.consecutiveFailures >= 5 {
|
||||
logging.Info("WebSocket reconnected successfully after %d failures", c.consecutiveFailures)
|
||||
}
|
||||
c.consecutiveFailures = 0
|
||||
go c.handleMessages() // Restart message handling
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
logging.Error("Failed to reconnect after %d attempts, giving up", len(backoffDurations))
|
||||
}
|
||||
Reference in New Issue
Block a user