Integrate with supabase backend
This commit is contained in:
410
internal/auth/auth.go
Normal file
410
internal/auth/auth.go
Normal file
@@ -0,0 +1,410 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nannyagentv2/internal/config"
|
||||
"nannyagentv2/internal/types"
|
||||
)
|
||||
|
||||
const (
|
||||
// Token storage location (secure directory)
|
||||
TokenStorageDir = "/var/lib/nannyagent"
|
||||
TokenStorageFile = ".agent_token.json"
|
||||
RefreshTokenFile = ".refresh_token"
|
||||
|
||||
// Polling configuration
|
||||
MaxPollAttempts = 60 // 5 minutes (60 * 5 seconds)
|
||||
PollInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
// AuthManager handles all authentication-related operations
|
||||
type AuthManager struct {
|
||||
config *config.Config
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewAuthManager creates a new authentication manager
|
||||
func NewAuthManager(cfg *config.Config) *AuthManager {
|
||||
return &AuthManager{
|
||||
config: cfg,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureTokenStorageDir creates the token storage directory if it doesn't exist
|
||||
func (am *AuthManager) EnsureTokenStorageDir() error {
|
||||
// Check if running as root
|
||||
if os.Geteuid() != 0 {
|
||||
return fmt.Errorf("must run as root to create secure token storage directory")
|
||||
}
|
||||
|
||||
// Create directory with restricted permissions (0700 - only root can access)
|
||||
if err := os.MkdirAll(TokenStorageDir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create token storage directory: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartDeviceAuthorization initiates the OAuth device authorization flow
|
||||
func (am *AuthManager) StartDeviceAuthorization() (*types.DeviceAuthResponse, error) {
|
||||
payload := map[string]interface{}{
|
||||
"client_id": "nannyagent-cli",
|
||||
"scope": []string{"agent:register"},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/device/authorize", am.config.DeviceAuthURL)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := am.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start device authorization: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("device authorization failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var deviceResp types.DeviceAuthResponse
|
||||
if err := json.Unmarshal(body, &deviceResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &deviceResp, nil
|
||||
}
|
||||
|
||||
// PollForToken polls the token endpoint until authorization is complete
|
||||
func (am *AuthManager) PollForToken(deviceCode string) (*types.TokenResponse, error) {
|
||||
fmt.Println("⏳ Waiting for user authorization...")
|
||||
|
||||
for attempts := 0; attempts < MaxPollAttempts; attempts++ {
|
||||
tokenReq := types.TokenRequest{
|
||||
GrantType: "urn:ietf:params:oauth:grant-type:device_code",
|
||||
DeviceCode: deviceCode,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal token request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/token", am.config.DeviceAuthURL)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := am.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to poll for token: %w", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token response: %w", err)
|
||||
}
|
||||
|
||||
var tokenResp types.TokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.Error != "" {
|
||||
if tokenResp.Error == "authorization_pending" {
|
||||
fmt.Print(".")
|
||||
time.Sleep(PollInterval)
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("authorization failed: %s", tokenResp.ErrorDescription)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken != "" {
|
||||
fmt.Println("\n✅ Authorization successful!")
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
time.Sleep(PollInterval)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("authorization timed out after %d attempts", MaxPollAttempts)
|
||||
}
|
||||
|
||||
// RefreshAccessToken refreshes an expired access token using the refresh token
|
||||
func (am *AuthManager) RefreshAccessToken(refreshToken string) (*types.TokenResponse, error) {
|
||||
tokenReq := types.TokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/token", am.config.DeviceAuthURL)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := am.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp types.TokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.Error != "" {
|
||||
return nil, fmt.Errorf("token refresh failed: %s", tokenResp.ErrorDescription)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// SaveToken saves the authentication token to secure local storage
|
||||
func (am *AuthManager) SaveToken(token *types.AuthToken) error {
|
||||
if err := am.EnsureTokenStorageDir(); err != nil {
|
||||
return fmt.Errorf("failed to ensure token storage directory: %w", err)
|
||||
}
|
||||
|
||||
// Save main token file
|
||||
tokenPath := am.getTokenPath()
|
||||
jsonData, err := json.MarshalIndent(token, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(tokenPath, jsonData, 0600); err != nil {
|
||||
return fmt.Errorf("failed to save token: %w", err)
|
||||
}
|
||||
|
||||
// Also save refresh token separately for backup recovery
|
||||
if token.RefreshToken != "" {
|
||||
refreshTokenPath := filepath.Join(TokenStorageDir, RefreshTokenFile)
|
||||
if err := os.WriteFile(refreshTokenPath, []byte(token.RefreshToken), 0600); err != nil {
|
||||
// Don't fail if refresh token backup fails, just log
|
||||
fmt.Printf("Warning: Failed to save backup refresh token: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
} // LoadToken loads the authentication token from secure local storage
|
||||
func (am *AuthManager) LoadToken() (*types.AuthToken, error) {
|
||||
tokenPath := am.getTokenPath()
|
||||
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
var token types.AuthToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if time.Now().After(token.ExpiresAt.Add(-5 * time.Minute)) {
|
||||
return nil, fmt.Errorf("token is expired or expiring soon")
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// IsTokenExpired checks if a token needs refresh
|
||||
func (am *AuthManager) IsTokenExpired(token *types.AuthToken) bool {
|
||||
// Consider token expired if it expires within the next 5 minutes
|
||||
return time.Now().After(token.ExpiresAt.Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
// RegisterDevice performs the complete device registration flow
|
||||
func (am *AuthManager) RegisterDevice() (*types.AuthToken, error) {
|
||||
// Step 1: Start device authorization
|
||||
deviceAuth, err := am.StartDeviceAuthorization()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start device authorization: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Please visit: %s\n", deviceAuth.VerificationURI)
|
||||
fmt.Printf("And enter code: %s\n", deviceAuth.UserCode)
|
||||
|
||||
// Step 2: Poll for token
|
||||
tokenResp, err := am.PollForToken(deviceAuth.DeviceCode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Create token storage
|
||||
token := &types.AuthToken{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
|
||||
AgentID: tokenResp.AgentID,
|
||||
}
|
||||
|
||||
// Step 4: Save token
|
||||
if err := am.SaveToken(token); err != nil {
|
||||
return nil, fmt.Errorf("failed to save token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// EnsureAuthenticated ensures the agent has a valid token, refreshing if necessary
|
||||
func (am *AuthManager) EnsureAuthenticated() (*types.AuthToken, error) {
|
||||
// Try to load existing token
|
||||
token, err := am.LoadToken()
|
||||
if err == nil && !am.IsTokenExpired(token) {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Try to refresh with existing refresh token (even if access token is missing/expired)
|
||||
var refreshToken string
|
||||
if err == nil && token.RefreshToken != "" {
|
||||
// Use refresh token from loaded token
|
||||
refreshToken = token.RefreshToken
|
||||
} else {
|
||||
// Try to load refresh token from main token file even if load failed
|
||||
if existingToken, loadErr := am.loadTokenIgnoringExpiry(); loadErr == nil && existingToken.RefreshToken != "" {
|
||||
refreshToken = existingToken.RefreshToken
|
||||
} else {
|
||||
// Try to load refresh token from backup file
|
||||
if backupRefreshToken, backupErr := am.loadRefreshTokenFromBackup(); backupErr == nil {
|
||||
refreshToken = backupRefreshToken
|
||||
fmt.Println("🔄 Found backup refresh token, attempting to use it...")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if refreshToken != "" {
|
||||
fmt.Println("🔄 Attempting to refresh access token...")
|
||||
|
||||
refreshResp, refreshErr := am.RefreshAccessToken(refreshToken)
|
||||
if refreshErr == nil {
|
||||
// Get existing agent_id from current token or backup
|
||||
var agentID string
|
||||
if err == nil && token.AgentID != "" {
|
||||
agentID = token.AgentID
|
||||
} else if existingToken, loadErr := am.loadTokenIgnoringExpiry(); loadErr == nil {
|
||||
agentID = existingToken.AgentID
|
||||
}
|
||||
|
||||
// Create new token with refreshed values
|
||||
newToken := &types.AuthToken{
|
||||
AccessToken: refreshResp.AccessToken,
|
||||
RefreshToken: refreshToken, // Keep existing refresh token
|
||||
TokenType: refreshResp.TokenType,
|
||||
ExpiresAt: time.Now().Add(time.Duration(refreshResp.ExpiresIn) * time.Second),
|
||||
AgentID: agentID, // Preserve agent_id
|
||||
}
|
||||
|
||||
// Update refresh token if a new one was provided
|
||||
if refreshResp.RefreshToken != "" {
|
||||
newToken.RefreshToken = refreshResp.RefreshToken
|
||||
}
|
||||
|
||||
if saveErr := am.SaveToken(newToken); saveErr == nil {
|
||||
return newToken, nil
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("⚠️ Token refresh failed: %v\n", refreshErr)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("📝 Initiating new device registration...")
|
||||
return am.RegisterDevice()
|
||||
}
|
||||
|
||||
// loadTokenIgnoringExpiry loads token file without checking expiry
|
||||
func (am *AuthManager) loadTokenIgnoringExpiry() (*types.AuthToken, error) {
|
||||
tokenPath := am.getTokenPath()
|
||||
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
var token types.AuthToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// loadRefreshTokenFromBackup tries to load refresh token from backup file
|
||||
func (am *AuthManager) loadRefreshTokenFromBackup() (string, error) {
|
||||
refreshTokenPath := filepath.Join(TokenStorageDir, RefreshTokenFile)
|
||||
|
||||
data, err := os.ReadFile(refreshTokenPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read refresh token backup: %w", err)
|
||||
}
|
||||
|
||||
refreshToken := strings.TrimSpace(string(data))
|
||||
if refreshToken == "" {
|
||||
return "", fmt.Errorf("refresh token backup is empty")
|
||||
}
|
||||
|
||||
return refreshToken, nil
|
||||
}
|
||||
|
||||
func (am *AuthManager) getTokenPath() string {
|
||||
if am.config.TokenPath != "" {
|
||||
return am.config.TokenPath
|
||||
}
|
||||
return filepath.Join(TokenStorageDir, TokenStorageFile)
|
||||
}
|
||||
|
||||
func getHostname() string {
|
||||
if hostname, err := os.Hostname(); err == nil {
|
||||
return hostname
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
Reference in New Issue
Block a user