// Package auth implements the agent's OAuth device-authorization flow and the
// on-disk persistence of the resulting tokens.
package auth

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"nannyagentv2/internal/config"
	"nannyagentv2/internal/types"
)

const (
	// Token storage location (secure directory).
	TokenStorageDir  = "/var/lib/nannyagent"
	TokenStorageFile = ".agent_token.json"
	RefreshTokenFile = ".refresh_token"

	// Polling configuration.
	MaxPollAttempts = 60 // 5 minutes (60 * 5 seconds)
	PollInterval    = 5 * time.Second
)

// AuthManager handles all authentication-related operations.
type AuthManager struct {
	config *config.Config
	client *http.Client
}

// NewAuthManager creates a new authentication manager whose HTTP client uses
// a 30-second timeout.
func NewAuthManager(cfg *config.Config) *AuthManager {
	return &AuthManager{
		config: cfg,
		client: &http.Client{
			Timeout: 30 * time.Second,
		},
	}
}

// EnsureTokenStorageDir creates the token storage directory if it doesn't
// exist. It returns an error when the process is not running as root or when
// the directory cannot be created.
func (am *AuthManager) EnsureTokenStorageDir() error {
	// Only root may create the secure storage directory.
	if os.Geteuid() != 0 {
		return fmt.Errorf("must run as root to create secure token storage directory")
	}

	// Create directory with restricted permissions (0700 - only root can access).
	if err := os.MkdirAll(TokenStorageDir, 0700); err != nil {
		return fmt.Errorf("failed to create token storage directory: %w", err)
	}

	return nil
}

// StartDeviceAuthorization initiates the OAuth device authorization flow by
// POSTing the client ID and requested scope to the device-authorize endpoint.
// It returns the decoded response, or an error when the request fails, the
// server answers with a status other than 200, or the body is not valid JSON.
func (am *AuthManager) StartDeviceAuthorization() (*types.DeviceAuthResponse, error) {
	payload := map[string]interface{}{
		"client_id": "nannyagent-cli",
		"scope":     []string{"agent:register"},
	}

	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal payload: %w", err)
	}

	endpoint := fmt.Sprintf("%s/device/authorize", am.config.DeviceAuthURL)
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := am.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to start device authorization: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("device authorization failed with status %d: %s", resp.StatusCode, string(body))
	}

	var deviceResp types.DeviceAuthResponse
	if err := json.Unmarshal(body, &deviceResp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	return &deviceResp, nil
}

// PollForToken polls the token endpoint until authorization is complete.
//
// It makes at most MaxPollAttempts requests, waiting PollInterval between
// them. An "authorization_pending" reply keeps polling at the current pace; a
// "slow_down" reply lengthens the wait by 5 seconds for all later requests and
// keeps polling. Any other error reply, a transport failure, or an unparsable
// body aborts with an error, as does exhausting the attempt budget.
func (am *AuthManager) PollForToken(deviceCode string) (*types.TokenResponse, error) {
	// RFC 8628 section 3.5: on "slow_down" the interval must be increased by
	// 5 seconds for this and all subsequent requests.
	const slowDownIncrement = 5 * time.Second

	fmt.Println("⏳ Waiting for user authorization...")

	// The payload and endpoint are identical for every attempt, so build them once.
	tokenReq := types.TokenRequest{
		GrantType:  "urn:ietf:params:oauth:grant-type:device_code",
		DeviceCode: deviceCode,
	}
	jsonData, err := json.Marshal(tokenReq)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal token request: %w", err)
	}
	endpoint := fmt.Sprintf("%s/token", am.config.DeviceAuthURL)

	interval := PollInterval
	for attempts := 0; attempts < MaxPollAttempts; attempts++ {
		req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(jsonData))
		if err != nil {
			return nil, fmt.Errorf("failed to create token request: %w", err)
		}
		req.Header.Set("Content-Type", "application/json")

		resp, err := am.client.Do(req)
		if err != nil {
			return nil, fmt.Errorf("failed to poll for token: %w", err)
		}

		// Close explicitly rather than with defer: a defer inside the loop
		// would keep every response body open until the function returns.
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			return nil, fmt.Errorf("failed to read token response: %w", err)
		}

		var tokenResp types.TokenResponse
		if err := json.Unmarshal(body, &tokenResp); err != nil {
			// Include the status so a non-JSON reply can be diagnosed.
			return nil, fmt.Errorf("failed to parse token response (status %d): %w", resp.StatusCode, err)
		}

		switch tokenResp.Error {
		case "":
			// No error reported; check for a token below.
		case "authorization_pending":
			fmt.Print(".")
			time.Sleep(interval)
			continue
		case "slow_down":
			interval += slowDownIncrement
			fmt.Print(".")
			time.Sleep(interval)
			continue
		default:
			return nil, fmt.Errorf("authorization failed: %s", tokenResp.ErrorDescription)
		}

		if tokenResp.AccessToken != "" {
			fmt.Println("\n✅ Authorization successful!")
			return &tokenResp, nil
		}

		// Neither an error nor a token: wait and try again.
		time.Sleep(interval)
	}

	return nil, fmt.Errorf("authorization timed out after %d attempts", MaxPollAttempts)
}

// RefreshAccessToken refreshes an expired access token using the refresh
// token. It returns an error when refreshToken is empty, the request fails,
// the server answers with a status other than 200, the body cannot be parsed,
// or the parsed response carries an error code.
func (am *AuthManager) RefreshAccessToken(refreshToken string) (*types.TokenResponse, error) {
	// Nothing useful can come from sending an empty credential.
	if refreshToken == "" {
		return nil, fmt.Errorf("refresh token is empty")
	}

	tokenReq := types.TokenRequest{
		GrantType:    "refresh_token",
		RefreshToken: refreshToken,
	}

	jsonData, err := json.Marshal(tokenReq)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
	}

	endpoint := fmt.Sprintf("%s/token", am.config.DeviceAuthURL)
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create refresh request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := am.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to refresh token: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read refresh response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
	}

	var tokenResp types.TokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("failed to parse refresh response: %w", err)
	}

	if tokenResp.Error != "" {
		return nil, fmt.Errorf("token refresh failed: %s", tokenResp.ErrorDescription)
	}

	return &tokenResp, nil
}

// SaveToken saves the authentication token to secure local storage.
//
// The full token is written as indented JSON to the path from getTokenPath.
// When the token carries a refresh token, that value is additionally written
// to a separate backup file; a failure of the backup write is only reported
// as a warning. A nil token is rejected.
func (am *AuthManager) SaveToken(token *types.AuthToken) error {
	if token == nil {
		return fmt.Errorf("cannot save nil token")
	}

	if err := am.EnsureTokenStorageDir(); err != nil {
		return fmt.Errorf("failed to ensure token storage directory: %w", err)
	}

	// Save main token file.
	tokenPath := am.getTokenPath()
	jsonData, err := json.MarshalIndent(token, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal token: %w", err)
	}

	if err := os.WriteFile(tokenPath, jsonData, 0600); err != nil {
		return fmt.Errorf("failed to save token: %w", err)
	}

	// Also save refresh token separately for backup recovery.
	if token.RefreshToken != "" {
		refreshTokenPath := filepath.Join(TokenStorageDir, RefreshTokenFile)
		if err := os.WriteFile(refreshTokenPath, []byte(token.RefreshToken), 0600); err != nil {
			// Don't fail if refresh token backup fails, just log.
			fmt.Printf("Warning: Failed to save backup refresh token: %v\n", err)
		}
	}

	return nil
}

// LoadToken loads the authentication token from secure local storage. It
// returns an error when the file cannot be read or parsed, or when the token
// is expired or expires within the next 5 minutes.
func (am *AuthManager) LoadToken() (*types.AuthToken, error) {
	token, err := am.loadTokenIgnoringExpiry()
	if err != nil {
		return nil, err
	}

	if am.IsTokenExpired(token) {
		return nil, fmt.Errorf("token is expired or expiring soon")
	}

	return token, nil
}

// IsTokenExpired checks if a token needs refresh. A token counts as expired
// when it expires within the next 5 minutes; a nil token is always expired.
func (am *AuthManager) IsTokenExpired(token *types.AuthToken) bool {
	if token == nil {
		return true
	}
	return time.Now().After(token.ExpiresAt.Add(-5 * time.Minute))
}

// RegisterDevice performs the complete device registration flow: it starts
// device authorization, prints the verification URI and user code, polls for
// the token, and saves the result.
func (am *AuthManager) RegisterDevice() (*types.AuthToken, error) {
	// Step 1: Start device authorization.
	deviceAuth, err := am.StartDeviceAuthorization()
	if err != nil {
		return nil, fmt.Errorf("failed to start device authorization: %w", err)
	}

	fmt.Printf("Please visit: %s\n", deviceAuth.VerificationURI)
	fmt.Printf("And enter code: %s\n", deviceAuth.UserCode)

	// Step 2: Poll for token.
	tokenResp, err := am.PollForToken(deviceAuth.DeviceCode)
	if err != nil {
		return nil, fmt.Errorf("failed to get token: %w", err)
	}

	// Step 3: Create token storage.
	token := &types.AuthToken{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
		TokenType:    tokenResp.TokenType,
		ExpiresAt:    time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
		AgentID:      tokenResp.AgentID,
	}

	// Step 4: Save token.
	if err := am.SaveToken(token); err != nil {
		return nil, fmt.Errorf("failed to save token: %w", err)
	}

	return token, nil
}

// EnsureAuthenticated ensures the agent has a valid token, refreshing if
// necessary.
//
// It returns the stored token when that is still valid. Otherwise it looks
// for a refresh token (first in the main token file, then in the backup file)
// and tries to refresh. If no refresh token is available, or the refresh or
// the subsequent save fails, it falls back to a full device registration.
func (am *AuthManager) EnsureAuthenticated() (*types.AuthToken, error) {
	// Try to load existing token.
	token, err := am.LoadToken()
	if err == nil && !am.IsTokenExpired(token) {
		return token, nil
	}

	// LoadToken also fails for expired tokens, so re-read the file without the
	// expiry check; the stored refresh token and agent ID are still wanted.
	stored := token
	if err != nil {
		// The error is deliberately dropped: a missing or unreadable file just
		// means there is nothing to carry over (stored stays nil).
		stored, _ = am.loadTokenIgnoringExpiry()
	}

	var refreshToken, agentID string
	if stored != nil {
		refreshToken = stored.RefreshToken
		agentID = stored.AgentID
	}

	// Fall back to the backup file when the main file has no refresh token.
	if refreshToken == "" {
		if backupRefreshToken, backupErr := am.loadRefreshTokenFromBackup(); backupErr == nil {
			refreshToken = backupRefreshToken
			fmt.Println("🔄 Found backup refresh token, attempting to use it...")
		}
	}

	if refreshToken != "" {
		fmt.Println("🔄 Attempting to refresh access token...")

		refreshResp, refreshErr := am.RefreshAccessToken(refreshToken)
		if refreshErr != nil {
			fmt.Printf("⚠️ Token refresh failed: %v\n", refreshErr)
		} else {
			// Create new token with refreshed values.
			newToken := &types.AuthToken{
				AccessToken:  refreshResp.AccessToken,
				RefreshToken: refreshToken, // Keep existing refresh token
				TokenType:    refreshResp.TokenType,
				ExpiresAt:    time.Now().Add(time.Duration(refreshResp.ExpiresIn) * time.Second),
				AgentID:      agentID, // Preserve agent_id
			}

			// Update refresh token if a new one was provided.
			if refreshResp.RefreshToken != "" {
				newToken.RefreshToken = refreshResp.RefreshToken
			}

			saveErr := am.SaveToken(newToken)
			if saveErr == nil {
				return newToken, nil
			}
			// Report why the refreshed token is being discarded before
			// falling back to registration.
			fmt.Printf("⚠️ Failed to save refreshed token: %v\n", saveErr)
		}
	}

	fmt.Println("📝 Initiating new device registration...")
	return am.RegisterDevice()
}

// loadTokenIgnoringExpiry loads and parses the token file without checking
// expiry.
func (am *AuthManager) loadTokenIgnoringExpiry() (*types.AuthToken, error) {
	tokenPath := am.getTokenPath()

	data, err := os.ReadFile(tokenPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read token file: %w", err)
	}

	var token types.AuthToken
	if err := json.Unmarshal(data, &token); err != nil {
		return nil, fmt.Errorf("failed to parse token: %w", err)
	}

	return &token, nil
}

// loadRefreshTokenFromBackup tries to load the refresh token from the backup
// file. Surrounding whitespace is trimmed; an empty result is an error.
func (am *AuthManager) loadRefreshTokenFromBackup() (string, error) {
	refreshTokenPath := filepath.Join(TokenStorageDir, RefreshTokenFile)

	data, err := os.ReadFile(refreshTokenPath)
	if err != nil {
		return "", fmt.Errorf("failed to read refresh token backup: %w", err)
	}

	refreshToken := strings.TrimSpace(string(data))
	if refreshToken == "" {
		return "", fmt.Errorf("refresh token backup is empty")
	}

	return refreshToken, nil
}

// getTokenPath returns the configured token path when one is set, and the
// default location inside TokenStorageDir otherwise.
func (am *AuthManager) getTokenPath() string {
	if am.config.TokenPath != "" {
		return am.config.TokenPath
	}
	return filepath.Join(TokenStorageDir, TokenStorageFile)
}

// getHostname returns the machine's hostname, or "unknown" when it cannot be
// determined.
func getHostname() string {
	if hostname, err := os.Hostname(); err == nil {
		return hostname
	}
	return "unknown"
}