Files
nannyagent/internal/auth/auth.go
Harshavardhan Musanalli 794111cb44 somewhat working ebpf bpftrace
2025-11-08 20:42:07 +01:00

511 lines
15 KiB
Go

package auth
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"nannyagentv2/internal/config"
"nannyagentv2/internal/logging"
"nannyagentv2/internal/types"
)
// On-disk locations used to persist credentials.
const (
	// TokenStorageDir is the default directory for credential files.
	TokenStorageDir = "/var/lib/nannyagent"
	// TokenStorageFile is the file name of the JSON-encoded token.
	TokenStorageFile = ".agent_token.json"
	// RefreshTokenFile is the file name of the standalone refresh token copy.
	RefreshTokenFile = ".refresh_token"
)

// Polling configuration for the device authorization flow.
const (
	// PollInterval is the pause between two consecutive token polls.
	PollInterval = 5 * time.Second
	// MaxPollAttempts caps the number of polls: 60 polls at 5 seconds each
	// is about 5 minutes.
	MaxPollAttempts = 60
)
// AuthManager handles all authentication-related operations: the device
// authorization flow, token refresh, and reading/writing credentials on disk.
type AuthManager struct {
// config supplies the auth endpoint base URL (DeviceAuthURL) and the
// optional token file override (TokenPath).
config *config.Config
// client is the HTTP client used for every request to the auth endpoints.
client *http.Client
}
// NewAuthManager builds an AuthManager bound to cfg. All requests issued by
// the manager go through a single HTTP client with a 30-second timeout.
func NewAuthManager(cfg *config.Config) *AuthManager {
	httpClient := &http.Client{Timeout: 30 * time.Second}
	return &AuthManager{config: cfg, client: httpClient}
}
// EnsureTokenStorageDir makes sure TokenStorageDir exists, creating it (and
// any missing parents) with mode 0700. It refuses to run unless the effective
// user is root.
func (am *AuthManager) EnsureTokenStorageDir() error {
	runningAsRoot := os.Geteuid() == 0
	if !runningAsRoot {
		return fmt.Errorf("must run as root to create secure token storage directory")
	}

	// 0700: owner-only access to the credential directory.
	mkdirErr := os.MkdirAll(TokenStorageDir, 0700)
	if mkdirErr != nil {
		return fmt.Errorf("failed to create token storage directory: %w", mkdirErr)
	}
	return nil
}
// StartDeviceAuthorization opens the OAuth device authorization flow by
// POSTing the client ID and requested scope to <DeviceAuthURL>/device/authorize.
// It returns the decoded response, or an error when the request cannot be
// built or sent, the body cannot be read, the status is not 200 OK, or the
// body is not valid JSON.
func (am *AuthManager) StartDeviceAuthorization() (*types.DeviceAuthResponse, error) {
	requestBody, err := json.Marshal(map[string]interface{}{
		"client_id": "nannyagent-cli",
		"scope":     []string{"agent:register"},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal payload: %w", err)
	}

	endpoint := fmt.Sprintf("%s/device/authorize", am.config.DeviceAuthURL)
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(requestBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := am.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to start device authorization: %w", err)
	}
	defer resp.Body.Close()

	// Read the body before checking the status so a failure can report it.
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	if status := resp.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("device authorization failed with status %d: %s", status, string(raw))
	}

	result := new(types.DeviceAuthResponse)
	if err := json.Unmarshal(raw, result); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	return result, nil
}
// PollForToken polls <DeviceAuthURL>/token with the given device code until
// the user completes authorization.
//
// It makes at most MaxPollAttempts requests, pausing between them. The pause
// starts at PollInterval and grows by 5 seconds every time the server answers
// "slow_down", as RFC 8628 section 3.5 requires. "authorization_pending" and
// "slow_down" keep the loop going; any other error code aborts it.
//
// It returns the token response once it carries an access token. It returns
// an error when a request cannot be built or sent, a response cannot be read
// or parsed, the server reports a terminal error, or every attempt is used up.
func (am *AuthManager) PollForToken(deviceCode string) (*types.TokenResponse, error) {
	logging.Info("Waiting for user authorization...")

	// The request body and URL are the same on every attempt; build them once.
	jsonData, err := json.Marshal(types.TokenRequest{
		GrantType:  "urn:ietf:params:oauth:grant-type:device_code",
		DeviceCode: deviceCode,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal token request: %w", err)
	}
	url := fmt.Sprintf("%s/token", am.config.DeviceAuthURL)

	interval := PollInterval
	for attempts := 0; attempts < MaxPollAttempts; attempts++ {
		req, err := http.NewRequest("POST", url, bytes.NewReader(jsonData))
		if err != nil {
			return nil, fmt.Errorf("failed to create token request: %w", err)
		}
		req.Header.Set("Content-Type", "application/json")

		resp, err := am.client.Do(req)
		if err != nil {
			return nil, fmt.Errorf("failed to poll for token: %w", err)
		}
		// Close explicitly rather than defer: a defer inside the loop would
		// keep every response body open until the function returns.
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			return nil, fmt.Errorf("failed to read token response: %w", err)
		}

		var tokenResp types.TokenResponse
		if err := json.Unmarshal(body, &tokenResp); err != nil {
			return nil, fmt.Errorf("failed to parse token response: %w", err)
		}

		switch tokenResp.Error {
		case "":
			if tokenResp.AccessToken != "" {
				logging.Info("Authorization successful!")
				return &tokenResp, nil
			}
		case "authorization_pending":
			fmt.Print(".")
		case "slow_down":
			// Still pending, but the server wants fewer requests: back off
			// for this and all subsequent polls.
			interval += 5 * time.Second
			fmt.Print(".")
		default:
			return nil, fmt.Errorf("authorization failed: %s", tokenResp.ErrorDescription)
		}

		time.Sleep(interval)
	}
	return nil, fmt.Errorf("authorization timed out after %d attempts", MaxPollAttempts)
}
// RefreshAccessToken exchanges refreshToken for a new access token by POSTing
// a "refresh_token" grant to <DeviceAuthURL>/token. It returns the decoded
// response, or an error when the request cannot be built or sent, the body
// cannot be read or parsed, the status is not 200 OK, or the response carries
// an error code.
func (am *AuthManager) RefreshAccessToken(refreshToken string) (*types.TokenResponse, error) {
	requestBody, err := json.Marshal(types.TokenRequest{
		GrantType:    "refresh_token",
		RefreshToken: refreshToken,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
	}

	endpoint := fmt.Sprintf("%s/token", am.config.DeviceAuthURL)
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(requestBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create refresh request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := am.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to refresh token: %w", err)
	}
	defer resp.Body.Close()

	// Read the body before checking the status so a failure can report it.
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read refresh response: %w", err)
	}
	if status := resp.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("token refresh failed with status %d: %s", status, string(raw))
	}

	refreshed := new(types.TokenResponse)
	if err := json.Unmarshal(raw, refreshed); err != nil {
		return nil, fmt.Errorf("failed to parse refresh response: %w", err)
	}
	// A 200 response can still carry an error code in its body.
	if refreshed.Error != "" {
		return nil, fmt.Errorf("token refresh failed: %s", refreshed.ErrorDescription)
	}
	return refreshed, nil
}
// SaveToken persists token to local storage. The full token is written as
// indented JSON to the path returned by getTokenPath (mode 0600). When the
// token carries a refresh token, that value is additionally written to
// RefreshTokenFile inside TokenStorageDir; a failure of this second write is
// only logged, never returned.
func (am *AuthManager) SaveToken(token *types.AuthToken) error {
	if dirErr := am.EnsureTokenStorageDir(); dirErr != nil {
		return fmt.Errorf("failed to ensure token storage directory: %w", dirErr)
	}

	// Primary copy: the whole token.
	encoded, err := json.MarshalIndent(token, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal token: %w", err)
	}
	if err = os.WriteFile(am.getTokenPath(), encoded, 0600); err != nil {
		return fmt.Errorf("failed to save token: %w", err)
	}

	// Secondary copy: the bare refresh token, kept for backup recovery.
	if token.RefreshToken == "" {
		return nil
	}
	backupPath := filepath.Join(TokenStorageDir, RefreshTokenFile)
	if backupErr := os.WriteFile(backupPath, []byte(token.RefreshToken), 0600); backupErr != nil {
		// Best effort: the primary copy is already on disk.
		logging.Warning("Failed to save backup refresh token: %v", backupErr)
	}
	return nil
}

// LoadToken loads the authentication token from secure local storage
// and rejects it when it has expired or will expire within the next five minutes.
func (am *AuthManager) LoadToken() (*types.AuthToken, error) {
	raw, err := os.ReadFile(am.getTokenPath())
	if err != nil {
		return nil, fmt.Errorf("failed to read token file: %w", err)
	}

	token := new(types.AuthToken)
	if err := json.Unmarshal(raw, token); err != nil {
		return nil, fmt.Errorf("failed to parse token: %w", err)
	}

	// Treat a token inside the five-minute safety margin as already expired.
	cutoff := token.ExpiresAt.Add(-5 * time.Minute)
	if time.Now().After(cutoff) {
		return nil, fmt.Errorf("token is expired or expiring soon")
	}
	return token, nil
}
// IsTokenExpired reports whether token needs a refresh, i.e. whether the
// current time is past the point five minutes before token.ExpiresAt.
func (am *AuthManager) IsTokenExpired(token *types.AuthToken) bool {
	refreshDeadline := token.ExpiresAt.Add(-5 * time.Minute)
	return time.Now().After(refreshDeadline)
}
// RegisterDevice runs the whole device registration flow: it starts device
// authorization, logs the verification URI and user code, polls until a token
// is granted, and saves the resulting token to disk. The saved token is
// returned; any failing step aborts the flow with a wrapped error.
func (am *AuthManager) RegisterDevice() (*types.AuthToken, error) {
	authResp, err := am.StartDeviceAuthorization()
	if err != nil {
		return nil, fmt.Errorf("failed to start device authorization: %w", err)
	}

	// Tell the operator where to go and what to type.
	logging.Info("Please visit: %s", authResp.VerificationURI)
	logging.Info("And enter code: %s", authResp.UserCode)

	granted, err := am.PollForToken(authResp.DeviceCode)
	if err != nil {
		return nil, fmt.Errorf("failed to get token: %w", err)
	}

	// ExpiresIn is a lifetime in seconds; turn it into an absolute instant.
	lifetime := time.Duration(granted.ExpiresIn) * time.Second
	token := &types.AuthToken{
		AccessToken:  granted.AccessToken,
		RefreshToken: granted.RefreshToken,
		TokenType:    granted.TokenType,
		ExpiresAt:    time.Now().Add(lifetime),
		AgentID:      granted.AgentID,
	}

	if err := am.SaveToken(token); err != nil {
		return nil, fmt.Errorf("failed to save token: %w", err)
	}
	return token, nil
}
// EnsureAuthenticated returns a usable token, obtaining one by the cheapest
// means available:
//
//  1. the token on disk, when LoadToken accepts it and it is not expired;
//  2. a refresh, using the refresh token from the loaded token, from the
//     on-disk token read without the expiry check, or from the backup
//     refresh-token file, in that order;
//  3. a full device registration via RegisterDevice.
//
// A refreshed token keeps the previous agent ID, and keeps the previous
// refresh token unless the server issued a new one. If the refresh fails, or
// the refreshed token cannot be saved, the failure is reported and the
// function falls back to device registration.
func (am *AuthManager) EnsureAuthenticated() (*types.AuthToken, error) {
	token, err := am.LoadToken()
	if err == nil && !am.IsTokenExpired(token) {
		return token, nil
	}

	// stored is the on-disk token read without the expiry check. It is read
	// once and backs both the refresh token and the agent ID lookups below.
	var stored *types.AuthToken
	if existing, loadErr := am.loadTokenIgnoringExpiry(); loadErr == nil {
		stored = existing
	}

	var refreshToken string
	switch {
	case err == nil && token.RefreshToken != "":
		refreshToken = token.RefreshToken
	case stored != nil && stored.RefreshToken != "":
		refreshToken = stored.RefreshToken
	default:
		if backup, backupErr := am.loadRefreshTokenFromBackup(); backupErr == nil {
			refreshToken = backup
			logging.Debug("Found backup refresh token, attempting to use it...")
		}
	}

	if refreshToken != "" {
		logging.Debug("Attempting to refresh access token...")
		refreshResp, refreshErr := am.RefreshAccessToken(refreshToken)
		if refreshErr == nil {
			// Carry the agent ID over from whichever copy of the old token has it.
			var agentID string
			if err == nil && token.AgentID != "" {
				agentID = token.AgentID
			} else if stored != nil {
				agentID = stored.AgentID
			}

			newToken := &types.AuthToken{
				AccessToken:  refreshResp.AccessToken,
				RefreshToken: refreshToken, // kept unless the server rotated it
				TokenType:    refreshResp.TokenType,
				ExpiresAt:    time.Now().Add(time.Duration(refreshResp.ExpiresIn) * time.Second),
				AgentID:      agentID,
			}
			if refreshResp.RefreshToken != "" {
				newToken.RefreshToken = refreshResp.RefreshToken
			}

			saveErr := am.SaveToken(newToken)
			if saveErr == nil {
				return newToken, nil
			}
			// Surface the reason before falling back to re-registration;
			// otherwise the fallback looks unexplained.
			logging.Warning("Failed to save refreshed token: %v", saveErr)
		} else {
			fmt.Printf("⚠️ Token refresh failed: %v\n", refreshErr)
		}
	}

	fmt.Println("📝 Initiating new device registration...")
	return am.RegisterDevice()
}
// loadTokenIgnoringExpiry reads and decodes the token file at getTokenPath
// without looking at the token's expiry time.
func (am *AuthManager) loadTokenIgnoringExpiry() (*types.AuthToken, error) {
	raw, readErr := os.ReadFile(am.getTokenPath())
	if readErr != nil {
		return nil, fmt.Errorf("failed to read token file: %w", readErr)
	}

	token := new(types.AuthToken)
	if parseErr := json.Unmarshal(raw, token); parseErr != nil {
		return nil, fmt.Errorf("failed to parse token: %w", parseErr)
	}
	return token, nil
}
// loadRefreshTokenFromBackup reads the standalone refresh token copy from
// RefreshTokenFile in TokenStorageDir. Surrounding whitespace is trimmed; an
// unreadable or blank file yields an error.
func (am *AuthManager) loadRefreshTokenFromBackup() (string, error) {
	raw, err := os.ReadFile(filepath.Join(TokenStorageDir, RefreshTokenFile))
	if err != nil {
		return "", fmt.Errorf("failed to read refresh token backup: %w", err)
	}

	if trimmed := strings.TrimSpace(string(raw)); trimmed != "" {
		return trimmed, nil
	}
	return "", fmt.Errorf("refresh token backup is empty")
}
// GetCurrentAgentID returns the agent ID. The cached value on disk is used
// when present. Otherwise the ID is taken from the 'sub' claim of the stored
// access token and written to the cache; a failed cache write is only
// printed as a warning.
func (am *AuthManager) GetCurrentAgentID() (string, error) {
	// Fast path: a previously cached ID.
	if cached, cacheErr := am.loadCachedAgentID(); cacheErr == nil && cached != "" {
		return cached, nil
	}

	// Slow path: decode the ID out of the stored access token.
	token, err := am.LoadToken()
	if err != nil {
		return "", fmt.Errorf("failed to load token: %w", err)
	}
	agentID, err := am.extractAgentIDFromJWT(token.AccessToken)
	if err != nil {
		return "", fmt.Errorf("failed to extract agent ID from JWT: %w", err)
	}

	// Caching is best effort; the ID is already in hand.
	if writeErr := am.cacheAgentID(agentID); writeErr != nil {
		fmt.Printf("Warning: Failed to cache agent ID: %v\n", writeErr)
	}
	return agentID, nil
}
// extractAgentIDFromJWT returns the agent ID carried in the 'sub' claim of a
// JWT.
//
// The token's signature is NOT verified: the token is split into its three
// dot-separated segments and only the middle (payload) segment is
// base64url-decoded and parsed as a JSON object.
//
// An error is returned when the token does not have exactly three segments,
// the payload cannot be decoded or parsed, or the 'sub' claim is missing,
// empty, or not a string.
func (am *AuthManager) extractAgentIDFromJWT(tokenString string) (string, error) {
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return "", fmt.Errorf("invalid JWT token format")
	}

	// JWT segments are base64url without padding. Trim any '=' a
	// non-conforming issuer left behind so the unpadded decoder accepts both
	// forms, instead of re-padding by hand.
	payload := strings.TrimRight(parts[1], "=")
	decoded, err := base64.RawURLEncoding.DecodeString(payload)
	if err != nil {
		return "", fmt.Errorf("failed to decode JWT payload: %w", err)
	}

	var claims map[string]interface{}
	if err := json.Unmarshal(decoded, &claims); err != nil {
		return "", fmt.Errorf("failed to parse JWT claims: %w", err)
	}

	// The agent ID is the subject claim; reject non-string or empty values.
	agentID, ok := claims["sub"].(string)
	if !ok || agentID == "" {
		return "", fmt.Errorf("agent ID (sub) not found in JWT claims")
	}
	return agentID, nil
}
// loadCachedAgentID returns the agent ID stored in the "agent_id" file inside
// TokenStorageDir. Surrounding whitespace is trimmed; an unreadable or blank
// file yields an error.
func (am *AuthManager) loadCachedAgentID() (string, error) {
	raw, err := os.ReadFile(filepath.Join(TokenStorageDir, "agent_id"))
	if err != nil {
		return "", fmt.Errorf("failed to read cached agent ID: %w", err)
	}

	if cached := strings.TrimSpace(string(raw)); cached != "" {
		return cached, nil
	}
	return "", fmt.Errorf("cached agent ID is empty")
}
// cacheAgentID writes agentID to the "agent_id" file inside TokenStorageDir
// (mode 0600), creating the directory first when necessary.
func (am *AuthManager) cacheAgentID(agentID string) error {
	if dirErr := am.EnsureTokenStorageDir(); dirErr != nil {
		return fmt.Errorf("failed to ensure storage directory: %w", dirErr)
	}

	// 0600: owner-only read/write on the cache file.
	cachePath := filepath.Join(TokenStorageDir, "agent_id")
	writeErr := os.WriteFile(cachePath, []byte(agentID), 0600)
	if writeErr != nil {
		return fmt.Errorf("failed to write agent ID cache: %w", writeErr)
	}
	return nil
}
// getTokenPath returns the token file location: the configured TokenPath when
// one is set, otherwise TokenStorageFile inside TokenStorageDir.
func (am *AuthManager) getTokenPath() string {
	if override := am.config.TokenPath; override != "" {
		return override
	}
	return filepath.Join(TokenStorageDir, TokenStorageFile)
}
// getHostname returns the host name reported by the operating system, or
// "unknown" when it cannot be determined.
func getHostname() string {
	name, err := os.Hostname()
	if err != nil {
		return "unknown"
	}
	return name
}