diff --git a/.gitignore b/.gitignore index 5d5e6ae..3f90785 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ go.work go.work.sum # env file -.env +.env* nannyagent* -nanny-agent* \ No newline at end of file +nanny-agent* +.vscode diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 473a0f4..0000000 diff --git a/go.mod b/go.mod index 4c0da5e..11af360 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,18 @@ toolchain go1.24.2 require ( github.com/cilium/ebpf v0.19.0 + github.com/joho/godotenv v1.5.1 github.com/sashabaranov/go-openai v1.32.0 + github.com/shirou/gopsutil/v3 v3.24.5 ) -require golang.org/x/sys v0.31.0 // indirect +require ( + github.com/go-ole/go-ole v1.2.6 // indirect + github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect + github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/shoenig/go-m1cpu v0.1.6 // indirect + github.com/tklauser/go-sysconf v0.3.12 // indirect + github.com/tklauser/numcpus v0.6.1 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect + golang.org/x/sys v0.31.0 // indirect +) diff --git a/go.sum b/go.sum index b8de438..451412b 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,16 @@ github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao= github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s= github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM= @@ -12,17 +19,44 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/sashabaranov/go-openai v1.32.0 h1:Yk3iE9moX3RBXxrof3OBtUBrE7qZR0zF9ebsoO4zVzI= github.com/sashabaranov/go-openai v1.32.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= +github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= +github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= +github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= +github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= +github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= +github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= +github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= +github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..8c9af45 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,410 @@ +package auth + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "nannyagentv2/internal/config" + "nannyagentv2/internal/types" +) + +const ( + // Token storage location (secure directory) + TokenStorageDir = "/var/lib/nannyagent" + TokenStorageFile = ".agent_token.json" + RefreshTokenFile = ".refresh_token" + + // Polling configuration + MaxPollAttempts = 60 // 5 minutes (60 * 5 seconds) + PollInterval = 5 * time.Second +) + +// AuthManager handles all authentication-related operations +type AuthManager struct { + config *config.Config + client *http.Client +} + +// NewAuthManager creates a new authentication manager +func NewAuthManager(cfg *config.Config) *AuthManager { + return &AuthManager{ + config: cfg, + client: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// EnsureTokenStorageDir creates the token storage directory if it doesn't exist +func (am *AuthManager) EnsureTokenStorageDir() error { + // Check if running as root + if os.Geteuid() != 0 { + return fmt.Errorf("must run as root to create secure token storage directory") + } + + // Create directory with restricted permissions (0700 - only root can access) + if err := os.MkdirAll(TokenStorageDir, 0700); err != nil { + return fmt.Errorf("failed to create token storage directory: %w", err) + } + + return nil +} + +// StartDeviceAuthorization initiates the OAuth device authorization flow +func (am *AuthManager) StartDeviceAuthorization() (*types.DeviceAuthResponse, error) { + payload := map[string]interface{}{ + "client_id": "nannyagent-cli", + "scope": []string{"agent:register"}, + } + + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %w", err) + } + + url := fmt.Sprintf("%s/device/authorize", am.config.DeviceAuthURL) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := am.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to start device authorization: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("device authorization failed with status %d: %s", resp.StatusCode, string(body)) + } + + var deviceResp types.DeviceAuthResponse + if err := json.Unmarshal(body, &deviceResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &deviceResp, nil +} + +// PollForToken polls the token endpoint until authorization is complete +func (am *AuthManager) PollForToken(deviceCode string) (*types.TokenResponse, error) { + fmt.Println("ā³ Waiting for user authorization...") + + for attempts := 0; attempts < MaxPollAttempts; attempts++ { + tokenReq := types.TokenRequest{ + GrantType: "urn:ietf:params:oauth:grant-type:device_code", + DeviceCode: deviceCode, + } + + jsonData, err := json.Marshal(tokenReq) + if err != nil { + return nil, fmt.Errorf("failed to marshal token request: %w", err) + } + + url := fmt.Sprintf("%s/token", am.config.DeviceAuthURL) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := am.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to poll for token: %w", err) + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + + if err != nil { + return nil, fmt.Errorf("failed to read token response: %w", err) + } + + var tokenResp types.TokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + if tokenResp.Error != "" { + if tokenResp.Error == "authorization_pending" { + fmt.Print(".") + time.Sleep(PollInterval) + continue + } + return nil, fmt.Errorf("authorization failed: %s", tokenResp.ErrorDescription) + } + + if tokenResp.AccessToken != "" { + fmt.Println("\nāœ… Authorization successful!") + return &tokenResp, nil + } + + time.Sleep(PollInterval) + } + + return nil, fmt.Errorf("authorization timed out after %d attempts", MaxPollAttempts) +} + +// RefreshAccessToken refreshes an expired access token using the refresh token +func (am *AuthManager) RefreshAccessToken(refreshToken string) (*types.TokenResponse, error) { + tokenReq := types.TokenRequest{ + GrantType: "refresh_token", + RefreshToken: refreshToken, + } + + jsonData, err := json.Marshal(tokenReq) + if err != nil { + return nil, fmt.Errorf("failed to marshal refresh request: %w", err) + } + + url := fmt.Sprintf("%s/token", am.config.DeviceAuthURL) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create refresh request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := am.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to refresh token: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read refresh response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp types.TokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse refresh response: %w", err) + } + + if tokenResp.Error != "" { + return nil, fmt.Errorf("token refresh failed: %s", tokenResp.ErrorDescription) + } + + return &tokenResp, nil +} + +// SaveToken saves the authentication token to secure local storage +func (am *AuthManager) SaveToken(token *types.AuthToken) error { + if err := am.EnsureTokenStorageDir(); err != nil { + return fmt.Errorf("failed to ensure token storage directory: %w", err) + } + + // Save main token file + tokenPath := am.getTokenPath() + jsonData, err := json.MarshalIndent(token, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal token: %w", err) + } + + if err := os.WriteFile(tokenPath, jsonData, 0600); err != nil { + return fmt.Errorf("failed to save token: %w", err) + } + + // Also save refresh token separately for backup recovery + if token.RefreshToken != "" { + refreshTokenPath := filepath.Join(TokenStorageDir, RefreshTokenFile) + if err := os.WriteFile(refreshTokenPath, []byte(token.RefreshToken), 0600); err != nil { + // Don't fail if refresh token backup fails, just log + fmt.Printf("Warning: Failed to save backup refresh token: %v\n", err) + } + } + + return nil +} // LoadToken loads the authentication token from secure local storage +func (am *AuthManager) LoadToken() (*types.AuthToken, error) { + tokenPath := am.getTokenPath() + + data, err := os.ReadFile(tokenPath) + if err != nil { + return nil, fmt.Errorf("failed to read token file: %w", err) + } + + var token types.AuthToken + if err := json.Unmarshal(data, &token); err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + // Check if token is expired + if time.Now().After(token.ExpiresAt.Add(-5 * time.Minute)) { + return nil, fmt.Errorf("token is expired or expiring soon") + } + + return &token, nil +} + +// IsTokenExpired checks if a token needs refresh +func (am *AuthManager) IsTokenExpired(token *types.AuthToken) bool { + // Consider token expired if it expires within the next 5 minutes + return time.Now().After(token.ExpiresAt.Add(-5 * time.Minute)) +} + +// RegisterDevice performs the complete device registration flow +func (am *AuthManager) RegisterDevice() (*types.AuthToken, error) { + // Step 1: Start device authorization + deviceAuth, err := am.StartDeviceAuthorization() + if err != nil { + return nil, fmt.Errorf("failed to start device authorization: %w", err) + } + + fmt.Printf("Please visit: %s\n", deviceAuth.VerificationURI) + fmt.Printf("And enter code: %s\n", deviceAuth.UserCode) + + // Step 2: Poll for token + tokenResp, err := am.PollForToken(deviceAuth.DeviceCode) + if err != nil { + return nil, fmt.Errorf("failed to get token: %w", err) + } + + // Step 3: Create token storage + token := &types.AuthToken{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + TokenType: tokenResp.TokenType, + ExpiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second), + AgentID: tokenResp.AgentID, + } + + // Step 4: Save token + if err := am.SaveToken(token); err != nil { + return nil, fmt.Errorf("failed to save token: %w", err) + } + + return token, nil +} + +// EnsureAuthenticated ensures the agent has a valid token, refreshing if necessary +func (am *AuthManager) EnsureAuthenticated() (*types.AuthToken, error) { + // Try to load existing token + token, err := am.LoadToken() + if err == nil && !am.IsTokenExpired(token) { + return token, nil + } + + // Try to refresh with existing refresh token (even if access token is missing/expired) + var refreshToken string + if err == nil && token.RefreshToken != "" { + // Use refresh token from loaded token + refreshToken = token.RefreshToken + } else { + // Try to load refresh token from main token file even if load failed + if existingToken, loadErr := am.loadTokenIgnoringExpiry(); loadErr == nil && existingToken.RefreshToken != "" { + refreshToken = existingToken.RefreshToken + } else { + // Try to load refresh token from backup file + if backupRefreshToken, backupErr := am.loadRefreshTokenFromBackup(); backupErr == nil { + refreshToken = backupRefreshToken + fmt.Println("šŸ”„ Found backup refresh token, attempting to use it...") + } + } + } + + if refreshToken != "" { + fmt.Println("šŸ”„ Attempting to refresh access token...") + + refreshResp, refreshErr := am.RefreshAccessToken(refreshToken) + if refreshErr == nil { + // Get existing agent_id from current token or backup + var agentID string + if err == nil && token.AgentID != "" { + agentID = token.AgentID + } else if existingToken, loadErr := am.loadTokenIgnoringExpiry(); loadErr == nil { + agentID = existingToken.AgentID + } + + // Create new token with refreshed values + newToken := &types.AuthToken{ + AccessToken: refreshResp.AccessToken, + RefreshToken: refreshToken, // Keep existing refresh token + TokenType: refreshResp.TokenType, + ExpiresAt: time.Now().Add(time.Duration(refreshResp.ExpiresIn) * time.Second), + AgentID: agentID, // Preserve agent_id + } + + // Update refresh token if a new one was provided + if refreshResp.RefreshToken != "" { + newToken.RefreshToken = refreshResp.RefreshToken + } + + if saveErr := am.SaveToken(newToken); saveErr == nil { + return newToken, nil + } + } else { + fmt.Printf("āš ļø Token refresh failed: %v\n", refreshErr) + } + } + + fmt.Println("šŸ“ Initiating new device registration...") + return am.RegisterDevice() +} + +// loadTokenIgnoringExpiry loads token file without checking expiry +func (am *AuthManager) loadTokenIgnoringExpiry() (*types.AuthToken, error) { + tokenPath := am.getTokenPath() + + data, err := os.ReadFile(tokenPath) + if err != nil { + return nil, fmt.Errorf("failed to read token file: %w", err) + } + + var token types.AuthToken + if err := json.Unmarshal(data, &token); err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + return &token, nil +} + +// loadRefreshTokenFromBackup tries to load refresh token from backup file +func (am *AuthManager) loadRefreshTokenFromBackup() (string, error) { + refreshTokenPath := filepath.Join(TokenStorageDir, RefreshTokenFile) + + data, err := os.ReadFile(refreshTokenPath) + if err != nil { + return "", fmt.Errorf("failed to read refresh token backup: %w", err) + } + + refreshToken := strings.TrimSpace(string(data)) + if refreshToken == "" { + return "", fmt.Errorf("refresh token backup is empty") + } + + return refreshToken, nil +} + +func (am *AuthManager) getTokenPath() string { + if am.config.TokenPath != "" { + return am.config.TokenPath + } + return filepath.Join(TokenStorageDir, TokenStorageFile) +} + +func getHostname() string { + if hostname, err := os.Hostname(); err == nil { + return hostname + } + return "unknown" +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..2229bbb --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,131 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/joho/godotenv" +) + +type Config struct { + // Supabase Configuration + SupabaseProjectURL string + + // Edge Function Endpoints (auto-generated from SupabaseProjectURL) + DeviceAuthURL string + AgentAuthURL string + + // Agent Configuration + TokenPath string + MetricsInterval int + + // Debug/Development + Debug bool +} + +var DefaultConfig = Config{ + TokenPath: "./token.json", + MetricsInterval: 30, + Debug: false, +} + +// LoadConfig loads configuration from environment variables and .env file +func LoadConfig() (*Config, error) { + config := DefaultConfig + + // Try to load .env file from current directory or parent directories + envFile := findEnvFile() + if envFile != "" { + if err := godotenv.Load(envFile); err != nil { + fmt.Printf("Warning: Could not load .env file from %s: %v\n", envFile, err) + } else { + fmt.Printf("Loaded configuration from %s\n", envFile) + } + } + + // Load from environment variables + if url := os.Getenv("SUPABASE_PROJECT_URL"); url != "" { + config.SupabaseProjectURL = url + } + + if tokenPath := os.Getenv("TOKEN_PATH"); tokenPath != "" { + config.TokenPath = tokenPath + } + + if debug := os.Getenv("DEBUG"); debug == "true" || debug == "1" { + config.Debug = true + } + + // Auto-generate edge function URLs from project URL + if config.SupabaseProjectURL != "" { + config.DeviceAuthURL = fmt.Sprintf("%s/functions/v1/device-auth", config.SupabaseProjectURL) + config.AgentAuthURL = fmt.Sprintf("%s/functions/v1/agent-auth-api", config.SupabaseProjectURL) + } + + // Validate required configuration + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("configuration validation failed: %w", err) + } + + return &config, nil +} + +// Validate checks if all required configuration is present +func (c *Config) Validate() error { + var missing []string + + if c.SupabaseProjectURL == "" { + missing = append(missing, "SUPABASE_PROJECT_URL") + } + + if c.DeviceAuthURL == "" { + missing = append(missing, "DEVICE_AUTH_URL (or SUPABASE_PROJECT_URL)") + } + + if c.AgentAuthURL == "" { + missing = append(missing, "AGENT_AUTH_URL (or SUPABASE_PROJECT_URL)") + } + + if len(missing) > 0 { + return fmt.Errorf("missing required environment variables: %s", strings.Join(missing, ", ")) + } + + return nil +} + +// findEnvFile looks for .env file in current directory and parent directories +func findEnvFile() string { + dir, err := os.Getwd() + if err != nil { + return "" + } + + for { + envPath := filepath.Join(dir, ".env") + if _, err := os.Stat(envPath); err == nil { + return envPath + } + + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + + return "" +} + +// PrintConfig prints the current configuration (masking sensitive values) +func (c *Config) PrintConfig() { + if !c.Debug { + return + } + + fmt.Println("Configuration:") + fmt.Printf(" Supabase Project URL: %s\n", c.SupabaseProjectURL) + fmt.Printf(" Metrics Interval: %d seconds\n", c.MetricsInterval) + fmt.Printf(" Debug: %v\n", c.Debug) +} diff --git a/internal/metrics/collector.go b/internal/metrics/collector.go new file mode 100644 index 0000000..fa01e40 --- /dev/null +++ b/internal/metrics/collector.go @@ -0,0 +1,315 @@ +package metrics + +import ( + "bytes" + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "math" + "net/http" + "strings" + "time" + + "github.com/shirou/gopsutil/v3/cpu" + "github.com/shirou/gopsutil/v3/disk" + "github.com/shirou/gopsutil/v3/host" + "github.com/shirou/gopsutil/v3/load" + "github.com/shirou/gopsutil/v3/mem" + psnet "github.com/shirou/gopsutil/v3/net" + + "nannyagentv2/internal/types" +) + +// Collector handles system metrics collection +type Collector struct { + agentVersion string +} + +// NewCollector creates a new metrics collector +func NewCollector(agentVersion string) *Collector { + return &Collector{ + agentVersion: agentVersion, + } +} + +// GatherSystemMetrics collects comprehensive system metrics +func (c *Collector) GatherSystemMetrics() (*types.SystemMetrics, error) { + metrics := &types.SystemMetrics{ + Timestamp: time.Now(), + } + + // System Information + if hostInfo, err := host.Info(); err == nil { + metrics.Hostname = hostInfo.Hostname + metrics.Platform = hostInfo.Platform + metrics.PlatformFamily = hostInfo.PlatformFamily + metrics.PlatformVersion = hostInfo.PlatformVersion + metrics.KernelVersion = hostInfo.KernelVersion + metrics.KernelArch = hostInfo.KernelArch + } + + // CPU Metrics + if percentages, err := cpu.Percent(time.Second, false); err == nil && len(percentages) > 0 { + metrics.CPUUsage = math.Round(percentages[0]*100) / 100 + } + + if cpuInfo, err := cpu.Info(); err == nil && len(cpuInfo) > 0 { + metrics.CPUCores = len(cpuInfo) + metrics.CPUModel = cpuInfo[0].ModelName + } + + // Memory Metrics + if memInfo, err := mem.VirtualMemory(); err == nil { + metrics.MemoryUsage = math.Round(float64(memInfo.Used)/(1024*1024)*100) / 100 // MB + metrics.MemoryTotal = memInfo.Total + metrics.MemoryUsed = memInfo.Used + metrics.MemoryFree = memInfo.Free + metrics.MemoryAvailable = memInfo.Available + } + + if swapInfo, err := mem.SwapMemory(); err == nil { + metrics.SwapTotal = swapInfo.Total + metrics.SwapUsed = swapInfo.Used + metrics.SwapFree = swapInfo.Free + } + + // Disk Metrics + if diskInfo, err := disk.Usage("/"); err == nil { + metrics.DiskUsage = math.Round(diskInfo.UsedPercent*100) / 100 + metrics.DiskTotal = diskInfo.Total + metrics.DiskUsed = diskInfo.Used + metrics.DiskFree = diskInfo.Free + } + + // Load Averages + if loadAvg, err := load.Avg(); err == nil { + metrics.LoadAvg1 = math.Round(loadAvg.Load1*100) / 100 + metrics.LoadAvg5 = math.Round(loadAvg.Load5*100) / 100 + metrics.LoadAvg15 = math.Round(loadAvg.Load15*100) / 100 + } + + // Process Count (simplified - using a constant for now) + // Note: gopsutil doesn't have host.Processes(), would need process.Processes() + metrics.ProcessCount = 0 // Placeholder + + // Network Metrics + netIn, netOut := c.getNetworkStats() + metrics.NetworkInKbps = netIn + metrics.NetworkOutKbps = netOut + + if netIOCounters, err := psnet.IOCounters(false); err == nil && len(netIOCounters) > 0 { + netIO := netIOCounters[0] + metrics.NetworkInBytes = netIO.BytesRecv + metrics.NetworkOutBytes = netIO.BytesSent + } + + // IP Address and Location + metrics.IPAddress = c.getIPAddress() + metrics.Location = c.getLocation() // Placeholder + + // Filesystem Information + metrics.FilesystemInfo = c.getFilesystemInfo() + + // Block Devices + metrics.BlockDevices = c.getBlockDevices() + + return metrics, nil +} + +// getNetworkStats returns network input/output rates in Kbps +func (c *Collector) getNetworkStats() (float64, float64) { + netIOCounters, err := psnet.IOCounters(false) + if err != nil || len(netIOCounters) == 0 { + return 0.0, 0.0 + } + + // Use the first interface for aggregate stats + netIO := netIOCounters[0] + + // Convert bytes to kilobits per second (simplified - cumulative bytes to kilobits) + netInKbps := float64(netIO.BytesRecv) * 8 / 1024 + netOutKbps := float64(netIO.BytesSent) * 8 / 1024 + + return netInKbps, netOutKbps +} + +// getIPAddress returns the primary IP address of the system +func (c *Collector) getIPAddress() string { + interfaces, err := psnet.Interfaces() + if err != nil { + return "unknown" + } + + for _, iface := range interfaces { + if len(iface.Addrs) > 0 && !strings.Contains(iface.Addrs[0].Addr, "127.0.0.1") { + return strings.Split(iface.Addrs[0].Addr, "/")[0] // Remove CIDR if present + } + } + + return "unknown" +} + +// getLocation returns basic location information (placeholder) +func (c *Collector) getLocation() string { + return "unknown" // Would integrate with GeoIP service +} + +// getFilesystemInfo returns information about mounted filesystems +func (c *Collector) getFilesystemInfo() []types.FilesystemInfo { + partitions, err := disk.Partitions(false) + if err != nil { + return []types.FilesystemInfo{} + } + + var filesystems []types.FilesystemInfo + for _, partition := range partitions { + usage, err := disk.Usage(partition.Mountpoint) + if err != nil { + continue + } + + fs := types.FilesystemInfo{ + Mountpoint: partition.Mountpoint, + Fstype: partition.Fstype, + Total: usage.Total, + Used: usage.Used, + Free: usage.Free, + UsagePercent: math.Round(usage.UsedPercent*100) / 100, + } + filesystems = append(filesystems, fs) + } + + return filesystems +} + +// getBlockDevices returns information about block devices +func (c *Collector) getBlockDevices() []types.BlockDevice { + partitions, err := disk.Partitions(true) + if err != nil { + return []types.BlockDevice{} + } + + var devices []types.BlockDevice + deviceMap := make(map[string]bool) + + for _, partition := range partitions { + // Only include actual block devices + if strings.HasPrefix(partition.Device, "/dev/") { + deviceName := partition.Device + if !deviceMap[deviceName] { + deviceMap[deviceName] = true + + device := types.BlockDevice{ + Name: deviceName, + Model: "unknown", + Size: 0, + SerialNumber: "unknown", + } + devices = append(devices, device) + } + } + } + + return devices +} + +// SendMetrics sends system metrics to the agent-auth-api endpoint +func (c *Collector) SendMetrics(agentAuthURL, accessToken, agentID string, metrics *types.SystemMetrics) error { + // Create flattened metrics request for agent-auth-api + metricsReq := c.CreateMetricsRequest(agentID, metrics) + + return c.sendMetricsRequest(agentAuthURL, accessToken, metricsReq) +} + +// CreateMetricsRequest converts SystemMetrics to the flattened format expected by agent-auth-api +func (c *Collector) CreateMetricsRequest(agentID string, systemMetrics *types.SystemMetrics) *types.MetricsRequest { + return &types.MetricsRequest{ + AgentID: agentID, + CPUUsage: systemMetrics.CPUUsage, + MemoryUsage: systemMetrics.MemoryUsage, + DiskUsage: systemMetrics.DiskUsage, + NetworkInKbps: systemMetrics.NetworkInKbps, + NetworkOutKbps: systemMetrics.NetworkOutKbps, + IPAddress: systemMetrics.IPAddress, + Location: systemMetrics.Location, + AgentVersion: c.agentVersion, + KernelVersion: systemMetrics.KernelVersion, + DeviceFingerprint: c.generateDeviceFingerprint(systemMetrics), + LoadAverages: map[string]float64{ + "load1": systemMetrics.LoadAvg1, + "load5": systemMetrics.LoadAvg5, + "load15": systemMetrics.LoadAvg15, + }, + OSInfo: map[string]string{ + "platform": systemMetrics.Platform, + "platform_family": systemMetrics.PlatformFamily, + "platform_version": systemMetrics.PlatformVersion, + "kernel_version": systemMetrics.KernelVersion, + "kernel_arch": systemMetrics.KernelArch, + }, + FilesystemInfo: systemMetrics.FilesystemInfo, + BlockDevices: systemMetrics.BlockDevices, + NetworkStats: map[string]uint64{ + "bytes_sent": systemMetrics.NetworkOutBytes, + "bytes_recv": systemMetrics.NetworkInBytes, + "total_bytes": systemMetrics.NetworkInBytes + systemMetrics.NetworkOutBytes, + }, + } +} + +// sendMetricsRequest sends the metrics request to the agent-auth-api +func (c *Collector) sendMetricsRequest(agentAuthURL, accessToken string, metricsReq *types.MetricsRequest) error { + // Wrap metrics in the expected payload structure + payload := map[string]interface{}{ + "metrics": metricsReq, + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + + jsonData, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal metrics: %w", err) + } + + // Send to /metrics endpoint + metricsURL := fmt.Sprintf("%s/metrics", agentAuthURL) + req, err := http.NewRequest("POST", metricsURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to send metrics: %w", err) + } + defer resp.Body.Close() + + // Read response + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + // Check response status + if resp.StatusCode == http.StatusUnauthorized { + return fmt.Errorf("unauthorized") + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("metrics request failed with status %d: %s", resp.StatusCode, string(body)) + } + + return nil +} + +// generateDeviceFingerprint creates a unique device identifier +func (c *Collector) generateDeviceFingerprint(metrics *types.SystemMetrics) string { + fingerprint := fmt.Sprintf("%s-%s-%s", metrics.Hostname, metrics.Platform, metrics.KernelVersion) + hasher := sha256.New() + hasher.Write([]byte(fingerprint)) + return fmt.Sprintf("%x", hasher.Sum(nil))[:16] +} diff --git a/internal/types/types.go b/internal/types/types.go new file mode 100644 index 0000000..1016da0 --- /dev/null +++ b/internal/types/types.go @@ -0,0 +1,170 @@ +package types + +import "time" + +// SystemMetrics represents comprehensive system performance metrics +type SystemMetrics struct { + // System Information + Hostname string `json:"hostname"` + Platform string `json:"platform"` + PlatformFamily string `json:"platform_family"` + PlatformVersion string `json:"platform_version"` + KernelVersion string `json:"kernel_version"` + KernelArch string `json:"kernel_arch"` + + // CPU Metrics + CPUUsage float64 `json:"cpu_usage"` + CPUCores int `json:"cpu_cores"` + CPUModel string `json:"cpu_model"` + + // Memory Metrics + MemoryUsage float64 `json:"memory_usage"` + MemoryTotal uint64 `json:"memory_total"` + MemoryUsed uint64 `json:"memory_used"` + MemoryFree uint64 `json:"memory_free"` + MemoryAvailable uint64 `json:"memory_available"` + SwapTotal uint64 `json:"swap_total"` + SwapUsed uint64 `json:"swap_used"` + SwapFree uint64 `json:"swap_free"` + + // Disk Metrics + DiskUsage float64 `json:"disk_usage"` + DiskTotal uint64 `json:"disk_total"` + DiskUsed uint64 `json:"disk_used"` + DiskFree uint64 `json:"disk_free"` + + // Network Metrics + NetworkInKbps float64 `json:"network_in_kbps"` + NetworkOutKbps float64 `json:"network_out_kbps"` + NetworkInBytes uint64 `json:"network_in_bytes"` + NetworkOutBytes uint64 `json:"network_out_bytes"` + + // System Load + LoadAvg1 float64 `json:"load_avg_1"` + LoadAvg5 float64 `json:"load_avg_5"` + LoadAvg15 float64 `json:"load_avg_15"` + + // Process Information + ProcessCount int `json:"process_count"` + + // Network Information + IPAddress string `json:"ip_address"` + Location string `json:"location"` + + // Filesystem Information + FilesystemInfo []FilesystemInfo `json:"filesystem_info"` + BlockDevices []BlockDevice `json:"block_devices"` + + // Timestamp + Timestamp time.Time `json:"timestamp"` +} + +// FilesystemInfo represents individual filesystem statistics +type FilesystemInfo struct { + Mountpoint string `json:"mountpoint"` + Fstype string `json:"fstype"` + Total uint64 `json:"total"` + Used uint64 `json:"used"` + Free uint64 `json:"free"` + UsagePercent float64 `json:"usage_percent"` +} + +// BlockDevice represents block device information +type BlockDevice struct { + Name string `json:"name"` + Size uint64 `json:"size"` + Model string `json:"model"` + SerialNumber string `json:"serial_number"` +} + +// NetworkStats represents detailed network interface statistics +type NetworkStats struct { + InterfaceName string `json:"interface_name"` + BytesSent uint64 `json:"bytes_sent"` + BytesRecv uint64 `json:"bytes_recv"` + PacketsSent uint64 `json:"packets_sent"` + PacketsRecv uint64 `json:"packets_recv"` + ErrorsIn uint64 `json:"errors_in"` + ErrorsOut uint64 `json:"errors_out"` + DropsIn uint64 `json:"drops_in"` + DropsOut uint64 `json:"drops_out"` +} + +// AuthToken represents the authentication token structure +type AuthToken struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresAt time.Time `json:"expires_at"` + TokenType string `json:"token_type"` + AgentID string `json:"agent_id"` +} + +// DeviceAuthRequest represents the device authorization request +type DeviceAuthRequest struct { + ClientID string `json:"client_id"` + Scope string `json:"scope,omitempty"` +} + +// DeviceAuthResponse represents the device authorization response +type DeviceAuthResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +// TokenRequest represents the token request for device flow +type TokenRequest struct { + GrantType string `json:"grant_type"` + DeviceCode string `json:"device_code,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ClientID string `json:"client_id,omitempty"` +} + +// TokenResponse represents the token response +type TokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + AgentID string `json:"agent_id,omitempty"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` +} + +// HeartbeatRequest represents the agent heartbeat request +type HeartbeatRequest struct { + AgentID string `json:"agent_id"` + Status string `json:"status"` + Metrics SystemMetrics `json:"metrics"` +} + +// MetricsRequest represents the flattened metrics payload expected by agent-auth-api +type MetricsRequest struct { + // Agent identification + AgentID string `json:"agent_id"` + + // Basic metrics + CPUUsage float64 `json:"cpu_usage"` + MemoryUsage float64 `json:"memory_usage"` + DiskUsage float64 `json:"disk_usage"` + + // Network metrics + NetworkInKbps float64 `json:"network_in_kbps"` + NetworkOutKbps float64 `json:"network_out_kbps"` + + // System information + IPAddress string `json:"ip_address"` + Location string `json:"location"` + AgentVersion string `json:"agent_version"` + KernelVersion string `json:"kernel_version"` + DeviceFingerprint string `json:"device_fingerprint"` + + // Structured data (JSON fields in database) + LoadAverages map[string]float64 `json:"load_averages"` + OSInfo map[string]string `json:"os_info"` + FilesystemInfo []FilesystemInfo `json:"filesystem_info"` + BlockDevices []BlockDevice `json:"block_devices"` + NetworkStats map[string]uint64 `json:"network_stats"` +} diff --git a/main.go b/main.go index dc58824..fc92087 100644 --- a/main.go +++ b/main.go @@ -1,143 +1,101 @@ package main import ( - "bufio" "fmt" "log" - "os" - "os/exec" - "strconv" - "strings" - "syscall" + "time" + + "nannyagentv2/internal/auth" + "nannyagentv2/internal/config" + "nannyagentv2/internal/metrics" + "nannyagentv2/internal/types" ) -// checkRootPrivileges ensures the program is running as root -func checkRootPrivileges() { - if os.Geteuid() != 0 { - fmt.Fprintf(os.Stderr, "āŒ ERROR: This program must be run as root for eBPF functionality.\n") - fmt.Fprintf(os.Stderr, "Please run with: sudo %s\n", os.Args[0]) - fmt.Fprintf(os.Stderr, "Reason: eBPF programs require root privileges to:\n") - fmt.Fprintf(os.Stderr, " - Load programs into the kernel\n") - fmt.Fprintf(os.Stderr, " - Attach to kernel functions and tracepoints\n") - fmt.Fprintf(os.Stderr, " - Access kernel memory maps\n") - os.Exit(1) - } -} - -// checkKernelVersionCompatibility ensures kernel version is 4.4 or higher -func checkKernelVersionCompatibility() { - output, err := exec.Command("uname", "-r").Output() - if err != nil { - fmt.Fprintf(os.Stderr, "āŒ ERROR: Cannot determine kernel version: %v\n", err) - os.Exit(1) - } - - kernelVersion := strings.TrimSpace(string(output)) - - // Parse version (e.g., "5.15.0-56-generic" -> major=5, minor=15) - parts := strings.Split(kernelVersion, ".") - if len(parts) < 2 { - fmt.Fprintf(os.Stderr, "āŒ ERROR: Cannot parse kernel version: %s\n", kernelVersion) - os.Exit(1) - } - - major, err := strconv.Atoi(parts[0]) - if err != nil { - fmt.Fprintf(os.Stderr, "āŒ ERROR: Cannot parse major kernel version: %s\n", parts[0]) - os.Exit(1) - } - - minor, err := strconv.Atoi(parts[1]) - if err != nil { - fmt.Fprintf(os.Stderr, "āŒ ERROR: Cannot parse minor kernel version: %s\n", parts[1]) - os.Exit(1) - } - - // Check if kernel is 4.4 or higher - if major < 4 || (major == 4 && minor < 4) { - fmt.Fprintf(os.Stderr, "āŒ ERROR: Kernel version %s is too old for eBPF.\n", kernelVersion) - fmt.Fprintf(os.Stderr, "Required: Linux kernel 4.4 or higher\n") - fmt.Fprintf(os.Stderr, "Current: %s\n", kernelVersion) - fmt.Fprintf(os.Stderr, "Reason: eBPF requires kernel features introduced in 4.4+:\n") - fmt.Fprintf(os.Stderr, " - BPF system call support\n") - fmt.Fprintf(os.Stderr, " - eBPF program types (kprobe, tracepoint)\n") - fmt.Fprintf(os.Stderr, " - BPF maps and helper functions\n") - os.Exit(1) - } - - fmt.Printf("āœ… Kernel version %s is compatible with eBPF\n", kernelVersion) -} - -// checkEBPFSupport validates eBPF subsystem availability -func checkEBPFSupport() { - // Check if /sys/kernel/debug/tracing exists (debugfs mounted) - if _, err := os.Stat("/sys/kernel/debug/tracing"); os.IsNotExist(err) { - fmt.Fprintf(os.Stderr, "āš ļø WARNING: debugfs not mounted. Some eBPF features may not work.\n") - fmt.Fprintf(os.Stderr, "To fix: sudo mount -t debugfs debugfs /sys/kernel/debug\n") - } - - // Check if we can access BPF syscall - fd, _, errno := syscall.Syscall(321, 0, 0, 0) // BPF syscall number on x86_64 - if errno != 0 && errno != syscall.EINVAL { - fmt.Fprintf(os.Stderr, "āŒ ERROR: BPF syscall not available (errno: %v)\n", errno) - fmt.Fprintf(os.Stderr, "This may indicate:\n") - fmt.Fprintf(os.Stderr, " - Kernel compiled without BPF support\n") - fmt.Fprintf(os.Stderr, " - BPF syscall disabled in kernel config\n") - os.Exit(1) - } - if fd > 0 { - syscall.Close(int(fd)) - } - - fmt.Printf("āœ… eBPF syscall is available\n") -} +const Version = "v2.0.0" func main() { - fmt.Println("šŸ” Linux eBPF-Enhanced Diagnostic Agent") - fmt.Println("=======================================") + fmt.Printf("šŸš€ NannyAgent v%s starting...\n", Version) - // Perform system compatibility checks - fmt.Println("Performing system compatibility checks...") + // Load configuration + cfg, err := config.LoadConfig() + if err != nil { + log.Fatalf("āŒ Failed to load configuration: %v", err) + } - checkRootPrivileges() - checkKernelVersionCompatibility() - checkEBPFSupport() + cfg.PrintConfig() - fmt.Println("āœ… All system checks passed") - fmt.Println("") + // Initialize components + authManager := auth.NewAuthManager(cfg) + metricsCollector := metrics.NewCollector(Version) - // Initialize the agent - agent := NewLinuxDiagnosticAgent() + // Ensure authentication + token, err := authManager.EnsureAuthenticated() + if err != nil { + log.Fatalf("āŒ Authentication failed: %v", err) + } - // Start the interactive session - fmt.Println("Linux Diagnostic Agent Started") - fmt.Println("Enter a system issue description (or 'quit' to exit):") + fmt.Println("āœ… Authentication successful!") - scanner := bufio.NewScanner(os.Stdin) + // Start metrics collection and heartbeat loop + fmt.Println("ā¤ļø Starting metrics collection and heartbeat...") + + ticker := time.NewTicker(time.Duration(cfg.MetricsInterval) * time.Second) + defer ticker.Stop() + + // Send initial heartbeat + if err := sendHeartbeat(cfg, token, metricsCollector); err != nil { + log.Printf("āš ļø Initial heartbeat failed: %v", err) + } + + // Main heartbeat loop for { - fmt.Print("> ") - if !scanner.Scan() { - break - } + select { + case <-ticker.C: + // Check if token needs refresh + if authManager.IsTokenExpired(token) { + fmt.Println("šŸ”„ Token expiring soon, refreshing...") + newToken, refreshErr := authManager.EnsureAuthenticated() + if refreshErr != nil { + log.Printf("āŒ Token refresh failed: %v", refreshErr) + continue + } + token = newToken + fmt.Println("āœ… Token refreshed successfully") + } - input := strings.TrimSpace(scanner.Text()) - if input == "quit" || input == "exit" { - break - } + // Send heartbeat + if err := sendHeartbeat(cfg, token, metricsCollector); err != nil { + log.Printf("āš ļø Heartbeat failed: %v", err) - if input == "" { - continue - } + // If unauthorized, try to refresh token + if err.Error() == "unauthorized" { + fmt.Println("šŸ”„ Unauthorized, attempting token refresh...") + newToken, refreshErr := authManager.EnsureAuthenticated() + if refreshErr != nil { + log.Printf("āŒ Token refresh failed: %v", refreshErr) + continue + } + token = newToken - // Process the issue with eBPF capabilities - if err := agent.DiagnoseWithEBPF(input); err != nil { - fmt.Printf("Error: %v\n", err) + // Retry heartbeat with new token (silently) + if retryErr := sendHeartbeat(cfg, token, metricsCollector); retryErr != nil { + log.Printf("āš ļø Retry heartbeat failed: %v", retryErr) + } + } + } + // No logging for successful heartbeats - they should be silent } } - - if err := scanner.Err(); err != nil { - log.Fatal(err) - } - - fmt.Println("Goodbye!") +} + +// sendHeartbeat collects metrics and sends heartbeat to the server +func sendHeartbeat(cfg *config.Config, token *types.AuthToken, collector *metrics.Collector) error { + // Collect system metrics + systemMetrics, err := collector.GatherSystemMetrics() + if err != nil { + return fmt.Errorf("failed to gather system metrics: %w", err) + } + + // Send metrics using the collector with correct agent_id from token + return collector.SendMetrics(cfg.AgentAuthURL, token.AccessToken, token.AgentID, systemMetrics) } diff --git a/demo_ebpf_integration.sh b/scripts/demo_ebpf_integration.sh similarity index 100% rename from demo_ebpf_integration.sh rename to scripts/demo_ebpf_integration.sh diff --git a/discover-functions.sh b/scripts/discover-functions.sh similarity index 100% rename from discover-functions.sh rename to scripts/discover-functions.sh diff --git a/ebpf_helper.sh b/scripts/ebpf_helper.sh similarity index 100% rename from ebpf_helper.sh rename to scripts/ebpf_helper.sh diff --git a/install.sh b/scripts/install.sh similarity index 100% rename from install.sh rename to scripts/install.sh diff --git a/integration-tests.sh b/scripts/integration-tests.sh similarity index 100% rename from integration-tests.sh rename to scripts/integration-tests.sh