jebin2 commited on
Commit
1bd7131
·
1 Parent(s): 945c3a3

google sign in

Browse files
.gitignore CHANGED
@@ -47,3 +47,7 @@ build/
47
  # Private keys (keep in repo but be careful)
48
  PRIVATE_KEY.pem
49
  client_secret.json
 
 
 
 
 
47
  # Private keys (keep in repo but be careful)
48
  PRIVATE_KEY.pem
49
  client_secret.json
50
+ test.jpg
51
+
52
+ # Downloaded video files
53
+ downloads/
app.py CHANGED
@@ -12,7 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
12
  from fastapi.responses import JSONResponse
13
 
14
  from core.database import init_db
15
- from routers import auth, blink, general
16
  from services.drive_service import DriveService
17
 
18
  # Configure logging
@@ -46,8 +46,18 @@ async def lifespan(app: FastAPI):
46
 
47
  await init_db()
48
  logger.info("Database initialized successfully")
 
 
 
 
 
 
49
  yield
50
 
 
 
 
 
51
  # Shutdown: Upload DB to Drive
52
  logger.info("Shutdown: Uploading database to Google Drive...")
53
  drive_service.upload_db()
@@ -75,6 +85,7 @@ app.add_middleware(
75
  app.include_router(general.router)
76
  app.include_router(auth.router)
77
  app.include_router(blink.router)
 
78
 
79
 
80
  @app.exception_handler(Exception)
 
12
  from fastapi.responses import JSONResponse
13
 
14
  from core.database import init_db
15
+ from routers import auth, blink, general, gemini
16
  from services.drive_service import DriveService
17
 
18
  # Configure logging
 
46
 
47
  await init_db()
48
  logger.info("Database initialized successfully")
49
+
50
+ # Start background job worker
51
+ from services.job_worker import start_worker, stop_worker
52
+ await start_worker()
53
+ logger.info("Background job worker started")
54
+
55
  yield
56
 
57
+ # Stop background job worker
58
+ await stop_worker()
59
+ logger.info("Background job worker stopped")
60
+
61
  # Shutdown: Upload DB to Drive
62
  logger.info("Shutdown: Uploading database to Google Drive...")
63
  drive_service.upload_db()
 
85
  app.include_router(general.router)
86
  app.include_router(auth.router)
87
  app.include_router(blink.router)
88
+ app.include_router(gemini.router)
89
 
90
 
91
  @app.exception_handler(Exception)
core/models.py CHANGED
@@ -37,6 +37,7 @@ class BlinkData(Base):
37
  class User(Base):
38
  """
39
  User model for credit system.
 
40
  """
41
  __tablename__ = "users"
42
 
@@ -44,7 +45,16 @@ class User(Base):
44
  user_id = Column(String(50), unique=True, index=True, nullable=False) # Backend generated UUID
45
  temp_user_id = Column(String(50), index=True, nullable=True) # From frontend
46
  email = Column(String(255), unique=True, index=True, nullable=False)
47
- secret_key_hash = Column(String(255), nullable=False)
 
 
 
 
 
 
 
 
 
48
  credits = Column(Integer, default=100)
49
  created_at = Column(DateTime(timezone=True), server_default=func.now())
50
  updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
@@ -55,6 +65,7 @@ class User(Base):
55
  return f"<User(id={self.id}, email={self.email})>"
56
 
57
 
 
58
  class RateLimit(Base):
59
  """
60
  Rate limit tracking table.
@@ -83,3 +94,26 @@ class AuditLog(Base):
83
  status = Column(String(20), nullable=False)
84
  error_message = Column(Text, nullable=True)
85
  timestamp = Column(DateTime(timezone=True), server_default=func.now())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  class User(Base):
38
  """
39
  User model for credit system.
40
+ Supports both legacy secret key and Google OAuth authentication.
41
  """
42
  __tablename__ = "users"
43
 
 
45
  user_id = Column(String(50), unique=True, index=True, nullable=False) # Backend generated UUID
46
  temp_user_id = Column(String(50), index=True, nullable=True) # From frontend
47
  email = Column(String(255), unique=True, index=True, nullable=False)
48
+
49
+ # Google OAuth fields
50
+ google_id = Column(String(255), unique=True, index=True, nullable=True) # Google sub claim
51
+ name = Column(String(255), nullable=True) # Display name from Google
52
+ profile_picture = Column(Text, nullable=True) # Google profile picture URL
53
+
54
+ # Legacy field (kept for migration, nullable now)
55
+ secret_key_hash = Column(String(255), nullable=True)
56
+
57
+ # Credits and status
58
  credits = Column(Integer, default=100)
59
  created_at = Column(DateTime(timezone=True), server_default=func.now())
60
  updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
 
65
  return f"<User(id={self.id}, email={self.email})>"
66
 
67
 
68
+
69
  class RateLimit(Base):
70
  """
71
  Rate limit tracking table.
 
94
  status = Column(String(20), nullable=False)
95
  error_message = Column(Text, nullable=True)
96
  timestamp = Column(DateTime(timezone=True), server_default=func.now())
97
+
98
+
99
+ class GeminiJob(Base):
100
+ """
101
+ Generic job queue for Gemini operations (video, image, text).
102
+ """
103
+ __tablename__ = "gemini_jobs"
104
+
105
+ id = Column(Integer, primary_key=True, autoincrement=True, index=True)
106
+ job_id = Column(String(100), unique=True, index=True, nullable=False) # Our ID for client
107
+ user_id = Column(String(50), index=True, nullable=False) # User who requested
108
+ job_type = Column(String(20), index=True, nullable=False) # video, image, text, analyze
109
+ third_party_id = Column(String(255), nullable=True) # Gemini operation name (for video)
110
+ status = Column(String(20), default="queued", index=True) # queued, processing, completed, failed
111
+ input_data = Column(JSON, nullable=True) # Request details (prompt, settings, etc.)
112
+ output_data = Column(JSON, nullable=True) # Result (filename, text, etc.)
113
+ error_message = Column(Text, nullable=True)
114
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
115
+ started_at = Column(DateTime(timezone=True), nullable=True)
116
+ completed_at = Column(DateTime(timezone=True), nullable=True)
117
+
118
+ def __repr__(self):
119
+ return f"<GeminiJob(job_id={self.job_id}, type={self.job_type}, status={self.status})>"
core/schemas.py CHANGED
@@ -1,12 +1,46 @@
1
  from pydantic import BaseModel, EmailStr, Field
 
 
 
2
 
3
- # Pydantic Models
4
  class CheckRegistrationRequest(BaseModel):
 
5
  user_id: str = Field(..., min_length=1, description="Temporary user ID from frontend")
6
 
7
- class RegisterRequest(BaseModel):
8
- user_id: str = Field(..., min_length=1, description="Temporary user ID from frontend")
9
- email: EmailStr = Field(..., description="User email address")
10
 
11
- class ResetRequest(BaseModel):
12
- email: EmailStr = Field(..., description="User email address")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pydantic import BaseModel, EmailStr, Field
2
+ from typing import Optional
3
+
4
+ # Pydantic Models for Google OAuth Authentication
5
 
 
6
  class CheckRegistrationRequest(BaseModel):
7
+ """Check if a user_id has completed registration."""
8
  user_id: str = Field(..., min_length=1, description="Temporary user ID from frontend")
9
 
 
 
 
10
 
11
+ class GoogleAuthRequest(BaseModel):
12
+ """Request with Google ID token from frontend Sign-In."""
13
+ id_token: str = Field(..., min_length=1, description="Google ID token from Sign-In")
14
+ temp_user_id: Optional[str] = Field(None, description="Optional temp user ID for linking")
15
+
16
+
17
+ class AuthResponse(BaseModel):
18
+ """Response after successful Google authentication."""
19
+ success: bool
20
+ access_token: str
21
+ user_id: str
22
+ email: str
23
+ name: Optional[str] = None
24
+ credits: int
25
+ is_new_user: bool
26
+
27
+
28
+ class UserInfoResponse(BaseModel):
29
+ """Response containing current user information."""
30
+ user_id: str
31
+ email: str
32
+ name: Optional[str] = None
33
+ credits: int
34
+ profile_picture: Optional[str] = None
35
+
36
+
37
+ class TokenRefreshRequest(BaseModel):
38
+ """Request to refresh an access token."""
39
+ token: str = Field(..., description="Current access token to refresh")
40
+
41
+
42
+ class TokenRefreshResponse(BaseModel):
43
+ """Response with refreshed access token."""
44
+ success: bool
45
+ access_token: str
46
+
core/security.py CHANGED
@@ -1,25 +1,32 @@
1
- import secrets
 
 
 
 
 
2
  import bcrypt
3
 
 
4
  def verify_password(plain_password: str, hashed_password: str) -> bool:
5
  """
6
- Verify a password against a hash.
 
 
 
7
  """
8
  if isinstance(hashed_password, str):
9
  hashed_password = hashed_password.encode('utf-8')
10
  return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password)
11
 
 
12
  def get_password_hash(password: str) -> str:
13
  """
14
  Hash a password using bcrypt.
 
 
 
15
  """
16
- # rounds=12 as per spec
17
  salt = bcrypt.gensalt(rounds=12)
18
  hashed = bcrypt.hashpw(password.encode('utf-8'), salt)
19
  return hashed.decode('utf-8')
20
 
21
- def generate_secret_key() -> str:
22
- """
23
- Generate a secure secret key starting with 'sk_'.
24
- """
25
- return "sk_" + secrets.token_urlsafe(32)
 
1
+ """
2
+ Core Security Utilities
3
+
4
+ Note: Secret key authentication has been replaced with Google OAuth.
5
+ The bcrypt functions below are kept for potential future use (e.g., admin passwords).
6
+ """
7
  import bcrypt
8
 
9
+
10
  def verify_password(plain_password: str, hashed_password: str) -> bool:
11
  """
12
+ Verify a password against a bcrypt hash.
13
+
14
+ Note: This is no longer used for user authentication (moved to Google OAuth).
15
+ Kept for potential admin/internal use cases.
16
  """
17
  if isinstance(hashed_password, str):
18
  hashed_password = hashed_password.encode('utf-8')
19
  return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password)
20
 
21
+
22
  def get_password_hash(password: str) -> str:
23
  """
24
  Hash a password using bcrypt.
25
+
26
+ Note: This is no longer used for user authentication (moved to Google OAuth).
27
+ Kept for potential admin/internal use cases.
28
  """
 
29
  salt = bcrypt.gensalt(rounds=12)
30
  hashed = bcrypt.hashpw(password.encode('utf-8'), salt)
31
  return hashed.decode('utf-8')
32
 
 
 
 
 
 
dependencies.py CHANGED
@@ -9,7 +9,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
9
 
10
  from core.database import get_db
11
  from core.models import User, RateLimit
12
- from core.security import verify_password
 
 
 
 
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
@@ -63,57 +68,107 @@ async def check_rate_limit(
63
  await db.commit()
64
  return True
65
 
66
- async def verify_credits(
 
67
  req: Request,
68
  db: AsyncSession = Depends(get_db)
69
  ) -> User:
70
  """
71
- Dependency to validate secret key and deduct credits.
 
 
 
 
 
 
72
  """
73
- secret_key = req.headers.get("X-Secret-Key")
74
- if not secret_key:
 
75
  raise HTTPException(
76
  status_code=status.HTTP_401_UNAUTHORIZED,
77
- detail="Missing X-Secret-Key header"
 
78
  )
79
 
80
- # Validate secret key format
81
- if not secret_key.startswith("sk_"):
82
  raise HTTPException(
83
  status_code=status.HTTP_401_UNAUTHORIZED,
84
- detail="Invalid secret key format"
 
85
  )
86
-
87
- # Find user
88
- query = select(User).where(User.is_active == True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  result = await db.execute(query)
90
- users = result.scalars().all()
91
-
92
- valid_user = None
93
- for user in users:
94
- if verify_password(secret_key, user.secret_key_hash):
95
- valid_user = user
96
- break
97
-
98
- if not valid_user:
99
  raise HTTPException(
100
  status_code=status.HTTP_401_UNAUTHORIZED,
101
- detail="Invalid secret key"
102
  )
103
-
104
- # Check credits
105
- if valid_user.credits <= 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  raise HTTPException(
107
  status_code=status.HTTP_402_PAYMENT_REQUIRED,
108
- detail="Insufficient credits"
109
  )
110
-
111
  # Deduct credit
112
- valid_user.credits -= 1
113
- valid_user.last_used_at = datetime.utcnow()
114
  await db.commit()
115
 
116
- return valid_user
 
 
 
117
 
118
  async def get_geolocation(ip_address: str) -> Tuple[Optional[str], Optional[str]]:
119
  """
 
9
 
10
  from core.database import get_db
11
  from core.models import User, RateLimit
12
+ from services.jwt_service import (
13
+ verify_access_token,
14
+ TokenExpiredError,
15
+ InvalidTokenError,
16
+ JWTError
17
+ )
18
 
19
  logger = logging.getLogger(__name__)
20
 
 
68
  await db.commit()
69
  return True
70
 
71
+
72
+ async def get_current_user(
73
  req: Request,
74
  db: AsyncSession = Depends(get_db)
75
  ) -> User:
76
  """
77
+ Extract and verify JWT from Authorization header.
78
+ Returns the authenticated user.
79
+
80
+ Usage:
81
+ @router.get("/protected")
82
+ async def protected_route(user: User = Depends(get_current_user)):
83
+ return {"user_id": user.user_id}
84
  """
85
+ auth_header = req.headers.get("Authorization")
86
+
87
+ if not auth_header:
88
  raise HTTPException(
89
  status_code=status.HTTP_401_UNAUTHORIZED,
90
+ detail="Missing Authorization header",
91
+ headers={"WWW-Authenticate": "Bearer"}
92
  )
93
 
94
+ if not auth_header.startswith("Bearer "):
 
95
  raise HTTPException(
96
  status_code=status.HTTP_401_UNAUTHORIZED,
97
+ detail="Invalid Authorization header format. Use: Bearer <token>",
98
+ headers={"WWW-Authenticate": "Bearer"}
99
  )
100
+
101
+ token = auth_header.split(" ", 1)[1]
102
+
103
+ try:
104
+ payload = verify_access_token(token)
105
+ except TokenExpiredError:
106
+ raise HTTPException(
107
+ status_code=status.HTTP_401_UNAUTHORIZED,
108
+ detail="Token has expired. Please sign in again.",
109
+ headers={"WWW-Authenticate": "Bearer"}
110
+ )
111
+ except InvalidTokenError as e:
112
+ raise HTTPException(
113
+ status_code=status.HTTP_401_UNAUTHORIZED,
114
+ detail=f"Invalid token: {str(e)}",
115
+ headers={"WWW-Authenticate": "Bearer"}
116
+ )
117
+ except JWTError as e:
118
+ raise HTTPException(
119
+ status_code=status.HTTP_401_UNAUTHORIZED,
120
+ detail=f"Authentication error: {str(e)}",
121
+ headers={"WWW-Authenticate": "Bearer"}
122
+ )
123
+
124
+ # Get user from DB
125
+ query = select(User).where(
126
+ User.user_id == payload.user_id,
127
+ User.is_active == True
128
+ )
129
  result = await db.execute(query)
130
+ user = result.scalar_one_or_none()
131
+
132
+ if not user:
 
 
 
 
 
 
133
  raise HTTPException(
134
  status_code=status.HTTP_401_UNAUTHORIZED,
135
+ detail="User not found or inactive"
136
  )
137
+
138
+ return user
139
+
140
+
141
+ async def verify_credits(
142
+ user: User = Depends(get_current_user),
143
+ db: AsyncSession = Depends(get_db)
144
+ ) -> User:
145
+ """
146
+ Verify user has credits and deduct one.
147
+
148
+ This dependency first authenticates the user via JWT,
149
+ then checks and deducts credits.
150
+
151
+ Usage:
152
+ @router.post("/api-endpoint")
153
+ async def api_endpoint(user: User = Depends(verify_credits)):
154
+ # User is authenticated and has 1 credit deducted
155
+ return {"credits_remaining": user.credits}
156
+ """
157
+ if user.credits <= 0:
158
  raise HTTPException(
159
  status_code=status.HTTP_402_PAYMENT_REQUIRED,
160
+ detail="Insufficient credits. Please purchase more credits."
161
  )
162
+
163
  # Deduct credit
164
+ user.credits -= 1
165
+ user.last_used_at = datetime.utcnow()
166
  await db.commit()
167
 
168
+ logger.debug(f"Deducted 1 credit from user {user.user_id}. Remaining: {user.credits}")
169
+
170
+ return user
171
+
172
 
173
  async def get_geolocation(ip_address: str) -> Tuple[Optional[str], Optional[str]]:
174
  """
docs/CLIENT_INTEGRATION.md ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Google Sign-In Client Integration Guide
2
+
3
+ Quick guide for frontend developers to integrate Google Sign-In with the APIGateway.
4
+
5
+ ## Setup
6
+
7
+ ### 1. Get Your Google Client ID
8
+
9
+ Use the same `GOOGLE_CLIENT_ID` that's configured on the backend.
10
+
11
+ ### 2. Add Google Identity Services Script
12
+
13
+ ```html
14
+ <script src="https://accounts.google.com/gsi/client" async defer></script>
15
+ ```
16
+
17
+ ---
18
+
19
+ ## Option A: One-Tap / Button Sign-In (Recommended)
20
+
21
+ ```html
22
+ <!DOCTYPE html>
23
+ <html>
24
+ <head>
25
+ <script src="https://accounts.google.com/gsi/client" async defer></script>
26
+ </head>
27
+ <body>
28
+ <!-- Google One Tap prompt -->
29
+ <div id="g_id_onload"
30
+ data-client_id="YOUR_GOOGLE_CLIENT_ID.apps.googleusercontent.com"
31
+ data-callback="handleGoogleSignIn"
32
+ data-auto_prompt="false">
33
+ </div>
34
+
35
+ <!-- Sign In Button -->
36
+ <div class="g_id_signin"
37
+ data-type="standard"
38
+ data-size="large"
39
+ data-theme="outline"
40
+ data-text="sign_in_with"
41
+ data-shape="rectangular"
42
+ data-logo_alignment="left">
43
+ </div>
44
+
45
+ <script>
46
+ const API_BASE = 'https://your-api-gateway.com'; // Change this
47
+
48
+ async function handleGoogleSignIn(response) {
49
+ try {
50
+ // Send Google token to your backend
51
+ const res = await fetch(`${API_BASE}/auth/google`, {
52
+ method: 'POST',
53
+ headers: { 'Content-Type': 'application/json' },
54
+ body: JSON.stringify({
55
+ id_token: response.credential,
56
+ temp_user_id: localStorage.getItem('temp_user_id') // optional
57
+ })
58
+ });
59
+
60
+ const data = await res.json();
61
+
62
+ if (data.success) {
63
+ // Store the access token
64
+ localStorage.setItem('access_token', data.access_token);
65
+ localStorage.setItem('user', JSON.stringify({
66
+ user_id: data.user_id,
67
+ email: data.email,
68
+ name: data.name,
69
+ credits: data.credits
70
+ }));
71
+
72
+ console.log('Signed in!', data.is_new_user ? 'New user' : 'Existing user');
73
+ // Redirect or update UI
74
+ window.location.reload();
75
+ } else {
76
+ alert('Sign in failed: ' + (data.detail || 'Unknown error'));
77
+ }
78
+ } catch (error) {
79
+ console.error('Sign in error:', error);
80
+ alert('Sign in failed. Please try again.');
81
+ }
82
+ }
83
+ </script>
84
+ </body>
85
+ </html>
86
+ ```
87
+
88
+ ---
89
+
90
+ ## Option B: Programmatic Sign-In
91
+
92
+ ```javascript
93
+ const API_BASE = 'https://your-api-gateway.com';
94
+
95
+ // Initialize Google Sign-In
96
+ function initGoogleSignIn(clientId) {
97
+ google.accounts.id.initialize({
98
+ client_id: clientId,
99
+ callback: handleGoogleSignIn
100
+ });
101
+
102
+ // Render button in a container
103
+ google.accounts.id.renderButton(
104
+ document.getElementById('google-signin-btn'),
105
+ { theme: 'outline', size: 'large', text: 'signin_with' }
106
+ );
107
+ }
108
+
109
+ // Handle the response
110
+ async function handleGoogleSignIn(response) {
111
+ const res = await fetch(`${API_BASE}/auth/google`, {
112
+ method: 'POST',
113
+ headers: { 'Content-Type': 'application/json' },
114
+ body: JSON.stringify({ id_token: response.credential })
115
+ });
116
+
117
+ const data = await res.json();
118
+ if (data.success) {
119
+ localStorage.setItem('access_token', data.access_token);
120
+ }
121
+ }
122
+
123
+ // Call on page load
124
+ window.onload = () => initGoogleSignIn('YOUR_CLIENT_ID.apps.googleusercontent.com');
125
+ ```
126
+
127
+ ---
128
+
129
+ ## Making API Calls
130
+
131
+ After sign-in, use the access token for all API calls:
132
+
133
+ ```javascript
134
+ const API_BASE = 'https://your-api-gateway.com';
135
+
136
+ async function apiCall(endpoint, options = {}) {
137
+ const token = localStorage.getItem('access_token');
138
+
139
+ if (!token) {
140
+ throw new Error('Not signed in');
141
+ }
142
+
143
+ const response = await fetch(`${API_BASE}${endpoint}`, {
144
+ ...options,
145
+ headers: {
146
+ 'Authorization': `Bearer ${token}`,
147
+ 'Content-Type': 'application/json',
148
+ ...options.headers
149
+ }
150
+ });
151
+
152
+ if (response.status === 401) {
153
+ // Token expired - need to sign in again
154
+ localStorage.removeItem('access_token');
155
+ window.location.href = '/login';
156
+ return;
157
+ }
158
+
159
+ return response.json();
160
+ }
161
+
162
+ // Examples
163
+ async function getCurrentUser() {
164
+ return apiCall('/auth/me');
165
+ }
166
+
167
+ async function generateText(prompt) {
168
+ return apiCall('/gemini/generate-text', {
169
+ method: 'POST',
170
+ body: JSON.stringify({ prompt })
171
+ });
172
+ }
173
+
174
+ async function checkJobStatus(jobId) {
175
+ return apiCall(`/gemini/job/${jobId}`);
176
+ }
177
+ ```
178
+
179
+ ---
180
+
181
+ ## API Endpoints Reference
182
+
183
+ | Endpoint | Method | Auth Required | Description |
184
+ |----------|--------|---------------|-------------|
185
+ | `/auth/google` | POST | ❌ | Sign in with Google token |
186
+ | `/auth/me` | GET | ✅ | Get current user info |
187
+ | `/auth/refresh` | POST | ❌ | Refresh access token |
188
+ | `/gemini/generate-text` | POST | ✅ | Generate text (costs 1 credit) |
189
+ | `/gemini/generate-video` | POST | ✅ | Generate video (costs 1 credit) |
190
+ | `/gemini/job/{job_id}` | GET | ✅ | Check job status |
191
+
192
+ ---
193
+
194
+ ## Sign Out
195
+
196
+ ```javascript
197
+ function signOut() {
198
+ localStorage.removeItem('access_token');
199
+ localStorage.removeItem('user');
200
+
201
+ // Optionally call logout endpoint for audit
202
+ fetch(`${API_BASE}/auth/logout`, {
203
+ method: 'POST',
204
+ headers: { 'Authorization': `Bearer ${token}` }
205
+ });
206
+
207
+ // Revoke Google session
208
+ google.accounts.id.disableAutoSelect();
209
+
210
+ window.location.href = '/';
211
+ }
212
+ ```
213
+
214
+ ---
215
+
216
+ ## Error Handling
217
+
218
+ | Status | Meaning | Action |
219
+ |--------|---------|--------|
220
+ | 401 | Token expired/invalid | Redirect to sign-in |
221
+ | 402 | Insufficient credits | Show "buy credits" prompt |
222
+ | 429 | Rate limited | Wait and retry |
223
+
224
+ ```javascript
225
+ async function apiCallWithErrorHandling(endpoint, options) {
226
+ const response = await apiCall(endpoint, options);
227
+
228
+ if (response.status === 402) {
229
+ alert('You have run out of credits!');
230
+ return null;
231
+ }
232
+
233
+ return response;
234
+ }
235
+ ```
236
+
237
+ ---
238
+
239
+ ## React Example
240
+
241
+ ```jsx
242
+ import { useEffect, useState } from 'react';
243
+
244
+ function App() {
245
+ const [user, setUser] = useState(null);
246
+
247
+ useEffect(() => {
248
+ // Check if already signed in
249
+ const token = localStorage.getItem('access_token');
250
+ if (token) {
251
+ fetch('/auth/me', {
252
+ headers: { 'Authorization': `Bearer ${token}` }
253
+ })
254
+ .then(r => r.json())
255
+ .then(setUser)
256
+ .catch(() => localStorage.removeItem('access_token'));
257
+ }
258
+
259
+ // Initialize Google Sign-In
260
+ window.handleGoogleSignIn = async (response) => {
261
+ const res = await fetch('/auth/google', {
262
+ method: 'POST',
263
+ headers: { 'Content-Type': 'application/json' },
264
+ body: JSON.stringify({ id_token: response.credential })
265
+ });
266
+ const data = await res.json();
267
+ if (data.success) {
268
+ localStorage.setItem('access_token', data.access_token);
269
+ setUser(data);
270
+ }
271
+ };
272
+ }, []);
273
+
274
+ if (!user) {
275
+ return (
276
+ <div>
277
+ <div id="g_id_onload"
278
+ data-client_id="YOUR_CLIENT_ID"
279
+ data-callback="handleGoogleSignIn" />
280
+ <div className="g_id_signin" data-type="standard" />
281
+ </div>
282
+ );
283
+ }
284
+
285
+ return <div>Welcome, {user.name}! Credits: {user.credits}</div>;
286
+ }
287
+ ```
generate_jwt_secret.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate JWT Secret Key
4
+
5
+ This script generates a cryptographically secure secret key for JWT signing.
6
+ Run this locally and add the generated key to your .env file.
7
+
8
+ Usage:
9
+ python generate_jwt_secret.py
10
+
11
+ # Or with custom length
12
+ python generate_jwt_secret.py --length 128
13
+
14
+ Output:
15
+ Prints the secret key and instructions for adding it to your environment.
16
+ """
17
+
18
+ import argparse
19
+ import secrets
20
+ import sys
21
+
22
+
23
+ def generate_secret(length: int = 64) -> str:
24
+ """
25
+ Generate a cryptographically secure URL-safe secret.
26
+
27
+ Args:
28
+ length: Number of bytes for the secret (default: 64).
29
+ The actual string length will be ~1.3x this due to base64 encoding.
30
+
31
+ Returns:
32
+ str: URL-safe base64 encoded secret.
33
+ """
34
+ return secrets.token_urlsafe(length)
35
+
36
+
37
+ def main():
38
+ parser = argparse.ArgumentParser(
39
+ description="Generate a secure JWT secret key",
40
+ formatter_class=argparse.RawDescriptionHelpFormatter,
41
+ epilog="""
42
+ Examples:
43
+ python generate_jwt_secret.py
44
+ python generate_jwt_secret.py --length 128
45
+ python generate_jwt_secret.py --format docker
46
+ """
47
+ )
48
+ parser.add_argument(
49
+ "--length", "-l",
50
+ type=int,
51
+ default=64,
52
+ help="Number of bytes for the secret (default: 64)"
53
+ )
54
+ parser.add_argument(
55
+ "--format", "-f",
56
+ choices=["env", "docker", "export", "raw"],
57
+ default="env",
58
+ help="Output format (default: env)"
59
+ )
60
+
61
+ args = parser.parse_args()
62
+
63
+ if args.length < 32:
64
+ print("Warning: Secret length should be at least 32 bytes for security.", file=sys.stderr)
65
+
66
+ secret = generate_secret(args.length)
67
+
68
+ print("\n" + "=" * 60)
69
+ print("🔐 Generated JWT Secret Key")
70
+ print("=" * 60)
71
+
72
+ if args.format == "raw":
73
+ print(secret)
74
+ elif args.format == "env":
75
+ print(f"\nAdd this line to your .env file:\n")
76
+ print(f"JWT_SECRET={secret}")
77
+ elif args.format == "docker":
78
+ print(f"\nAdd this to your docker-compose.yml environment:\n")
79
+ print(f" - JWT_SECRET={secret}")
80
+ elif args.format == "export":
81
+ print(f"\nRun this command to set the environment variable:\n")
82
+ print(f"export JWT_SECRET='{secret}'")
83
+
84
+ print("\n" + "-" * 60)
85
+ print("⚠️ IMPORTANT SECURITY NOTES:")
86
+ print("-" * 60)
87
+ print("• Keep this secret confidential - never commit it to git")
88
+ print("• Use different secrets for development and production")
89
+ print("• If compromised, all existing tokens become invalid")
90
+ print("• Store securely (e.g., secrets manager, encrypted env)")
91
+ print("=" * 60 + "\n")
92
+
93
+
94
+ if __name__ == "__main__":
95
+ main()
requirements.txt CHANGED
@@ -13,3 +13,6 @@ python-dotenv>=1.0.0
13
  google-api-python-client>=2.0.0
14
  google-auth-oauthlib>=1.0.0
15
  google-auth-httplib2>=0.1.0
 
 
 
 
13
  google-api-python-client>=2.0.0
14
  google-auth-oauthlib>=1.0.0
15
  google-auth-httplib2>=0.1.0
16
+ google-genai>=1.0.0
17
+ PyJWT>=2.8.0
18
+
routers/auth.py CHANGED
@@ -1,21 +1,48 @@
1
- from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks, Header
 
 
 
 
 
 
2
  from fastapi.responses import JSONResponse
3
  from sqlalchemy.ext.asyncio import AsyncSession
4
  from sqlalchemy import select
5
  from datetime import datetime
6
  import uuid
 
7
 
8
  from core.database import get_db
9
  from core.models import User, AuditLog
10
- from core.schemas import CheckRegistrationRequest, RegisterRequest, ResetRequest
11
- from core.security import get_password_hash, verify_password, generate_secret_key
12
- from services.email_service import send_email
13
- from dependencies import check_rate_limit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from services.drive_service import DriveService
15
 
 
 
16
  router = APIRouter(prefix="/auth", tags=["auth"])
17
  drive_service = DriveService()
18
 
 
19
  @router.post("/check-registration")
20
  async def check_registration(
21
  request: CheckRegistrationRequest,
@@ -24,6 +51,7 @@ async def check_registration(
24
  ):
25
  """
26
  Check if a temporary user_id has completed registration.
 
27
  """
28
  # Rate Limit: 10 requests per minute per IP
29
  ip = req.client.host
@@ -37,227 +65,209 @@ async def check_registration(
37
  return {"is_registered": user is not None}
38
 
39
 
40
- @router.post("/register")
41
- async def register(
42
- request: RegisterRequest,
43
  req: Request,
44
  background_tasks: BackgroundTasks,
45
  db: AsyncSession = Depends(get_db)
46
  ):
47
  """
48
- Register new user, generate secret key, send email.
 
 
 
 
 
 
 
 
 
49
  """
50
- # Rate Limit: 5 registrations per hour per IP
51
  ip = req.client.host
52
- if not await check_rate_limit(db, ip, "/auth/register", 5, 60):
53
- raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many registration attempts")
54
-
55
- # Check Email Already Registered
56
- query = select(User).where(User.email == request.email)
57
- result = await db.execute(query)
58
- if result.scalar_one_or_none():
59
- raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already registered")
60
-
61
- # Check temp_user_id Already Registered
62
- query = select(User).where(User.temp_user_id == request.user_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  result = await db.execute(query)
64
- if result.scalar_one_or_none():
65
- raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="User already registered")
66
-
67
- # Generate Secret Key
68
- secret_key = generate_secret_key()
69
- secret_key_hash = get_password_hash(secret_key)
70
- backend_user_id = "usr_" + str(uuid.uuid4())
71
-
72
- # Create User
73
- new_user = User(
74
- user_id=backend_user_id,
75
- temp_user_id=request.user_id,
76
- email=request.email,
77
- secret_key_hash=secret_key_hash,
78
- credits=100
79
- )
80
- db.add(new_user)
81
 
82
- # Log Audit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  audit_log = AuditLog(
84
- user_id=backend_user_id,
85
- action="register",
86
  ip_address=ip,
87
  status="success"
88
  )
89
  db.add(audit_log)
90
-
91
  await db.commit()
92
-
93
- # Send Email (Async)
94
- email_body = f"""Hello,
95
-
96
- Your secret key is: {secret_key}
97
-
98
- Please save this key securely.
99
- You'll need it to access your credits.
100
-
101
- If you lose this key, use the 'Forgot Key'
102
- option with this email address.
103
-
104
- Credits: 100
105
- Valid from: {datetime.now().strftime('%Y-%m-%d')}
106
-
107
- Do not share this key with anyone."""
108
-
109
- background_tasks.add_task(
110
- send_email,
111
- request.email,
112
- "Your Secret Key - Credit System",
113
- email_body
114
- )
115
 
116
  # Sync DB to Drive (Async)
117
  background_tasks.add_task(drive_service.upload_db)
 
 
 
 
 
 
 
 
 
 
118
 
119
- return {"success": True, "message": "Registration successful. Check your email."}
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- @router.post("/validate")
123
- async def validate_key(
 
 
124
  req: Request,
125
- background_tasks: BackgroundTasks,
126
- x_secret_key: str = Header(..., alias="X-Secret-Key"),
127
  db: AsyncSession = Depends(get_db)
128
  ):
129
  """
130
- Validate secret key and return user info.
 
 
 
 
131
  """
132
- # Rate Limit: 20 validations per hour per IP
133
  ip = req.client.host
134
- if not await check_rate_limit(db, ip, "/auth/validate", 20, 60):
135
- raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many validation attempts")
136
-
137
- query = select(User).where(User.is_active == True)
138
- result = await db.execute(query)
139
- users = result.scalars().all()
140
-
141
- valid_user = None
142
- for user in users:
143
- if verify_password(x_secret_key, user.secret_key_hash):
144
- valid_user = user
145
- break
146
 
147
- if valid_user:
148
- # Update last_used_at
149
- valid_user.last_used_at = datetime.utcnow()
150
-
151
- # Log Audit
152
- audit_log = AuditLog(
153
- user_id=valid_user.user_id,
154
- action="validate",
155
- ip_address=ip,
156
- status="success"
157
  )
158
- db.add(audit_log)
159
- await db.commit()
160
-
161
- # Sync DB to Drive (Async) - Optional but good for audit logs
162
- background_tasks.add_task(drive_service.upload_db)
163
 
164
- return {
165
- "valid": True,
166
- "user_id": valid_user.user_id,
167
- "credits": valid_user.credits,
168
- "message": "Valid key"
169
- }
170
- else:
171
- # Log Audit
172
- audit_log = AuditLog(
173
- user_id=None,
174
- action="validate",
175
- ip_address=ip,
176
- status="failed",
177
- error_message="Invalid key"
178
  )
179
- db.add(audit_log)
180
- await db.commit()
181
-
182
- return JSONResponse(
183
  status_code=status.HTTP_401_UNAUTHORIZED,
184
- content={"valid": False, "message": "Invalid secret key"}
185
  )
186
 
187
 
188
- @router.post("/reset")
189
- async def reset_key(
190
- request: ResetRequest,
191
  req: Request,
192
  background_tasks: BackgroundTasks,
 
193
  db: AsyncSession = Depends(get_db)
194
  ):
195
  """
196
- Reset/recover secret key via email.
197
- """
198
- # Rate Limit: 3 reset attempts per hour per email
199
- if not await check_rate_limit(db, request.email, "/auth/reset", 3, 60):
200
- raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many reset attempts")
201
-
202
- query = select(User).where(User.email == request.email)
203
- result = await db.execute(query)
204
- user = result.scalar_one_or_none()
205
 
 
 
 
206
  ip = req.client.host
207
-
208
- if user:
209
- # Generate New Secret Key
210
- new_secret_key = generate_secret_key()
211
- new_secret_key_hash = get_password_hash(new_secret_key)
212
-
213
- user.secret_key_hash = new_secret_key_hash
214
- user.updated_at = datetime.utcnow()
215
-
216
- # Log Audit
217
- audit_log = AuditLog(
218
- user_id=user.user_id,
219
- action="reset",
220
- ip_address=ip,
221
- status="success"
222
- )
223
- db.add(audit_log)
224
- await db.commit()
225
-
226
- # Send Email
227
- email_body = f"""Hello,
228
-
229
- You requested a secret key reset.
230
-
231
- Your NEW secret key is: {new_secret_key}
232
-
233
- Your old secret key is now invalid.
234
-
235
- If you didn't request this, please
236
- contact support immediately.
237
-
238
- Current Credits: {user.credits}"""
239
-
240
- background_tasks.add_task(
241
- send_email,
242
- request.email,
243
- "Your New Secret Key",
244
- email_body
245
- )
246
-
247
- # Sync DB to Drive (Async)
248
- background_tasks.add_task(drive_service.upload_db)
249
-
250
- else:
251
- # Log Audit (failed/not found)
252
- audit_log = AuditLog(
253
- user_id=None,
254
- action="reset",
255
- ip_address=ip,
256
- status="failed",
257
- error_message="Email not found"
258
- )
259
- db.add(audit_log)
260
- await db.commit()
261
-
262
- # Always return success
263
- return {"success": True, "message": "If this email is registered, reset instructions have been sent."}
 
1
+ """
2
+ Authentication Router - Google OAuth
3
+
4
+ Endpoints for Google Sign-In authentication flow.
5
+ No more secret keys - users authenticate with their Google account.
6
+ """
7
+ from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
8
  from fastapi.responses import JSONResponse
9
  from sqlalchemy.ext.asyncio import AsyncSession
10
  from sqlalchemy import select
11
  from datetime import datetime
12
  import uuid
13
+ import logging
14
 
15
  from core.database import get_db
16
  from core.models import User, AuditLog
17
+ from core.schemas import (
18
+ CheckRegistrationRequest,
19
+ GoogleAuthRequest,
20
+ AuthResponse,
21
+ UserInfoResponse,
22
+ TokenRefreshRequest,
23
+ TokenRefreshResponse
24
+ )
25
+ from services.google_auth_service import (
26
+ GoogleAuthService,
27
+ InvalidTokenError as GoogleInvalidTokenError,
28
+ ConfigurationError as GoogleConfigError,
29
+ get_google_auth_service
30
+ )
31
+ from services.jwt_service import (
32
+ JWTService,
33
+ create_access_token,
34
+ get_jwt_service,
35
+ InvalidTokenError as JWTInvalidTokenError
36
+ )
37
+ from dependencies import check_rate_limit, get_current_user
38
  from services.drive_service import DriveService
39
 
40
+ logger = logging.getLogger(__name__)
41
+
42
  router = APIRouter(prefix="/auth", tags=["auth"])
43
  drive_service = DriveService()
44
 
45
+
46
  @router.post("/check-registration")
47
  async def check_registration(
48
  request: CheckRegistrationRequest,
 
51
  ):
52
  """
53
  Check if a temporary user_id has completed registration.
54
+ Useful for frontend to check if user needs to sign in.
55
  """
56
  # Rate Limit: 10 requests per minute per IP
57
  ip = req.client.host
 
65
  return {"is_registered": user is not None}
66
 
67
 
68
+ @router.post("/google", response_model=AuthResponse)
69
+ async def google_auth(
70
+ request: GoogleAuthRequest,
71
  req: Request,
72
  background_tasks: BackgroundTasks,
73
  db: AsyncSession = Depends(get_db)
74
  ):
75
  """
76
+ Authenticate with Google ID token.
77
+
78
+ Frontend flow:
79
+ 1. User clicks "Sign in with Google" button
80
+ 2. Google returns an ID token
81
+ 3. Frontend sends that token to this endpoint
82
+ 4. We verify it with Google and issue our own JWT
83
+
84
+ Creates new user or returns existing user.
85
+ Existing users matched by email.
86
  """
 
87
  ip = req.client.host
88
+
89
+ # Rate Limit: 10 attempts per minute per IP
90
+ if not await check_rate_limit(db, ip, "/auth/google", 10, 1):
91
+ raise HTTPException(
92
+ status_code=status.HTTP_429_TOO_MANY_REQUESTS,
93
+ detail="Too many authentication attempts"
94
+ )
95
+
96
+ # Verify Google token
97
+ try:
98
+ google_service = get_google_auth_service()
99
+ google_info = google_service.verify_token(request.id_token)
100
+ except GoogleConfigError as e:
101
+ logger.error(f"Google Auth not configured: {e}")
102
+ raise HTTPException(
103
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
104
+ detail="Google authentication is not configured"
105
+ )
106
+ except GoogleInvalidTokenError as e:
107
+ logger.warning(f"Invalid Google token from {ip}: {e}")
108
+
109
+ # Log failed attempt
110
+ audit_log = AuditLog(
111
+ user_id=None,
112
+ action="google_auth",
113
+ ip_address=ip,
114
+ status="failed",
115
+ error_message=str(e)
116
+ )
117
+ db.add(audit_log)
118
+ await db.commit()
119
+
120
+ raise HTTPException(
121
+ status_code=status.HTTP_401_UNAUTHORIZED,
122
+ detail="Invalid Google token. Please try signing in again."
123
+ )
124
+
125
+ # Check for existing user by email (preserves credits for migrated users)
126
+ query = select(User).where(User.email == google_info.email)
127
  result = await db.execute(query)
128
+ user = result.scalar_one_or_none()
129
+
130
+ is_new_user = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ if user:
133
+ # Existing user - update Google info
134
+ if not user.google_id:
135
+ user.google_id = google_info.google_id
136
+ logger.info(f"Linked Google account to existing user: {user.email}")
137
+
138
+ user.name = google_info.name
139
+ user.profile_picture = google_info.picture
140
+ user.last_used_at = datetime.utcnow()
141
+
142
+ # Link temp_user_id if provided and not already set
143
+ if request.temp_user_id and not user.temp_user_id:
144
+ user.temp_user_id = request.temp_user_id
145
+ else:
146
+ # New user - create account
147
+ is_new_user = True
148
+ user = User(
149
+ user_id="usr_" + str(uuid.uuid4()),
150
+ temp_user_id=request.temp_user_id,
151
+ email=google_info.email,
152
+ google_id=google_info.google_id,
153
+ name=google_info.name,
154
+ profile_picture=google_info.picture,
155
+ credits=100
156
+ )
157
+ db.add(user)
158
+ logger.info(f"New user created via Google: {google_info.email}")
159
+
160
+ # Log successful auth
161
  audit_log = AuditLog(
162
+ user_id=user.user_id,
163
+ action="google_auth",
164
  ip_address=ip,
165
  status="success"
166
  )
167
  db.add(audit_log)
 
168
  await db.commit()
169
+
170
+ # Create our JWT access token
171
+ access_token = create_access_token(user.user_id, user.email)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  # Sync DB to Drive (Async)
174
  background_tasks.add_task(drive_service.upload_db)
175
+
176
+ return AuthResponse(
177
+ success=True,
178
+ access_token=access_token,
179
+ user_id=user.user_id,
180
+ email=user.email,
181
+ name=user.name,
182
+ credits=user.credits,
183
+ is_new_user=is_new_user
184
+ )
185
 
 
186
 
187
+ @router.get("/me", response_model=UserInfoResponse)
188
+ async def get_current_user_info(
189
+ user: User = Depends(get_current_user)
190
+ ):
191
+ """
192
+ Get current authenticated user info.
193
+
194
+ Requires Authorization: Bearer <token> header.
195
+ """
196
+ return UserInfoResponse(
197
+ user_id=user.user_id,
198
+ email=user.email,
199
+ name=user.name,
200
+ credits=user.credits,
201
+ profile_picture=user.profile_picture
202
+ )
203
 
204
+
205
+ @router.post("/refresh", response_model=TokenRefreshResponse)
206
+ async def refresh_token(
207
+ request: TokenRefreshRequest,
208
  req: Request,
 
 
209
  db: AsyncSession = Depends(get_db)
210
  ):
211
  """
212
+ Refresh an access token.
213
+
214
+ Use this when the current token is about to expire
215
+ (or has recently expired) to get a new one without
216
+ requiring the user to sign in again.
217
  """
 
218
  ip = req.client.host
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
+ # Rate Limit: 5 refreshes per minute per IP
221
+ if not await check_rate_limit(db, ip, "/auth/refresh", 5, 1):
222
+ raise HTTPException(
223
+ status_code=status.HTTP_429_TOO_MANY_REQUESTS,
224
+ detail="Too many refresh attempts"
 
 
 
 
 
225
  )
226
+
227
+ try:
228
+ jwt_service = get_jwt_service()
229
+ new_token = jwt_service.refresh_token(request.token)
 
230
 
231
+ return TokenRefreshResponse(
232
+ success=True,
233
+ access_token=new_token
 
 
 
 
 
 
 
 
 
 
 
234
  )
235
+ except JWTInvalidTokenError as e:
236
+ raise HTTPException(
 
 
237
  status_code=status.HTTP_401_UNAUTHORIZED,
238
+ detail=f"Cannot refresh token: {str(e)}"
239
  )
240
 
241
 
242
+ @router.post("/logout")
243
+ async def logout(
 
244
  req: Request,
245
  background_tasks: BackgroundTasks,
246
+ user: User = Depends(get_current_user),
247
  db: AsyncSession = Depends(get_db)
248
  ):
249
  """
250
+ Logout current user.
251
+
252
+ Note: JWT tokens are stateless, so this endpoint mainly
253
+ serves to log the action. Frontend should discard the token.
 
 
 
 
 
254
 
255
+ For full session invalidation, consider implementing
256
+ a token blacklist or reducing token expiry times.
257
+ """
258
  ip = req.client.host
259
+
260
+ # Log logout
261
+ audit_log = AuditLog(
262
+ user_id=user.user_id,
263
+ action="logout",
264
+ ip_address=ip,
265
+ status="success"
266
+ )
267
+ db.add(audit_log)
268
+ await db.commit()
269
+
270
+ # Sync DB to Drive (Async)
271
+ background_tasks.add_task(drive_service.upload_db)
272
+
273
+ return {"success": True, "message": "Logged out successfully"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
routers/gemini.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gemini Router - API endpoints for Gemini AI services.
3
+ Uses job queue pattern for async processing.
4
+ Authentication via JWT (Authorization: Bearer <token>).
5
+ """
6
+ import os
7
+ import uuid
8
+ from fastapi import APIRouter, Depends, HTTPException, status
9
+ from fastapi.responses import FileResponse
10
+ from pydantic import BaseModel, Field
11
+ from typing import Optional, Literal
12
+ from sqlalchemy.ext.asyncio import AsyncSession
13
+ from sqlalchemy import select, func
14
+
15
+ from core.database import get_db
16
+ from core.models import User, GeminiJob
17
+ from services.gemini_service import MODELS, DOWNLOADS_DIR
18
+ from dependencies import verify_credits, get_current_user
19
+ from datetime import datetime
20
+
21
+ router = APIRouter(prefix="/gemini", tags=["gemini"])
22
+
23
+
24
+ # Request/Response Models
25
+ class GenerateAnimationPromptRequest(BaseModel):
26
+ base64_image: str = Field(..., description="Base64 encoded image data")
27
+ mime_type: str = Field(..., description="MIME type of the image (e.g., image/png)")
28
+ custom_prompt: Optional[str] = Field(None, description="Optional custom prompt for analysis")
29
+
30
+
31
+ class EditImageRequest(BaseModel):
32
+ base64_image: str = Field(..., description="Base64 encoded image data")
33
+ mime_type: str = Field(..., description="MIME type of the image")
34
+ prompt: str = Field(..., description="Edit instructions")
35
+
36
+
37
+ class GenerateVideoRequest(BaseModel):
38
+ base64_image: str = Field(..., description="Base64 encoded image data")
39
+ mime_type: str = Field(..., description="MIME type of the image")
40
+ prompt: str = Field(..., description="Video generation prompt")
41
+ aspect_ratio: Literal["16:9", "9:16"] = Field("16:9", description="Video aspect ratio")
42
+ resolution: Literal["720p", "1080p"] = Field("720p", description="Video resolution")
43
+ number_of_videos: int = Field(1, ge=1, le=4, description="Number of videos to generate")
44
+
45
+
46
+ class GenerateTextRequest(BaseModel):
47
+ prompt: str = Field(..., description="Text prompt")
48
+ model: Optional[str] = Field(None, description="Model to use (defaults to gemini-2.5-flash)")
49
+
50
+
51
+ class AnalyzeImageRequest(BaseModel):
52
+ base64_image: str = Field(..., description="Base64 encoded image data")
53
+ mime_type: str = Field(..., description="MIME type of the image")
54
+ prompt: str = Field(..., description="Analysis prompt")
55
+
56
+
57
+ # Authentication is handled by dependencies.verify_credits
58
+ # which uses JWT tokens from Authorization: Bearer <token> header
59
+
60
+
61
+ async def get_queue_position(db: AsyncSession, job_id: str) -> int:
62
+ """Get the position of a job in the queue."""
63
+ query = select(func.count()).where(
64
+ GeminiJob.status == "queued",
65
+ GeminiJob.created_at < select(GeminiJob.created_at).where(GeminiJob.job_id == job_id).scalar_subquery()
66
+ )
67
+ result = await db.execute(query)
68
+ return result.scalar() + 1
69
+
70
+
71
+ async def create_job(
72
+ db: AsyncSession,
73
+ user: User,
74
+ job_type: str,
75
+ input_data: dict
76
+ ) -> GeminiJob:
77
+ """Create a new job in the queue."""
78
+ job_id = f"job_{uuid.uuid4().hex[:16]}"
79
+
80
+ job = GeminiJob(
81
+ job_id=job_id,
82
+ user_id=user.user_id,
83
+ job_type=job_type,
84
+ status="queued",
85
+ input_data=input_data
86
+ )
87
+ db.add(job)
88
+ await db.commit()
89
+ await db.refresh(job)
90
+
91
+ return job
92
+
93
+
94
+ @router.post("/generate-animation-prompt")
95
+ async def generate_animation_prompt(
96
+ request: GenerateAnimationPromptRequest,
97
+ user: User = Depends(verify_credits),
98
+ db: AsyncSession = Depends(get_db)
99
+ ):
100
+ """
101
+ Queue an animation prompt generation job.
102
+ """
103
+ job = await create_job(
104
+ db=db,
105
+ user=user,
106
+ job_type="animation_prompt",
107
+ input_data={
108
+ "base64_image": request.base64_image,
109
+ "mime_type": request.mime_type,
110
+ "custom_prompt": request.custom_prompt
111
+ }
112
+ )
113
+
114
+ position = await get_queue_position(db, job.job_id)
115
+
116
+ return {
117
+ "success": True,
118
+ "job_id": job.job_id,
119
+ "status": "queued",
120
+ "position": position,
121
+ "credits_remaining": user.credits
122
+ }
123
+
124
+
125
+ @router.post("/edit-image")
126
+ async def edit_image(
127
+ request: EditImageRequest,
128
+ user: User = Depends(verify_credits),
129
+ db: AsyncSession = Depends(get_db)
130
+ ):
131
+ """
132
+ Queue an image edit job.
133
+ """
134
+ job = await create_job(
135
+ db=db,
136
+ user=user,
137
+ job_type="image",
138
+ input_data={
139
+ "base64_image": request.base64_image,
140
+ "mime_type": request.mime_type,
141
+ "prompt": request.prompt
142
+ }
143
+ )
144
+
145
+ position = await get_queue_position(db, job.job_id)
146
+
147
+ return {
148
+ "success": True,
149
+ "job_id": job.job_id,
150
+ "status": "queued",
151
+ "position": position,
152
+ "credits_remaining": user.credits
153
+ }
154
+
155
+
156
+ @router.post("/generate-video")
157
+ async def generate_video(
158
+ request: GenerateVideoRequest,
159
+ user: User = Depends(verify_credits),
160
+ db: AsyncSession = Depends(get_db)
161
+ ):
162
+ """
163
+ Queue a video generation job.
164
+ """
165
+ job = await create_job(
166
+ db=db,
167
+ user=user,
168
+ job_type="video",
169
+ input_data={
170
+ "base64_image": request.base64_image,
171
+ "mime_type": request.mime_type,
172
+ "prompt": request.prompt,
173
+ "aspect_ratio": request.aspect_ratio,
174
+ "resolution": request.resolution,
175
+ "number_of_videos": request.number_of_videos
176
+ }
177
+ )
178
+
179
+ position = await get_queue_position(db, job.job_id)
180
+
181
+ return {
182
+ "success": True,
183
+ "job_id": job.job_id,
184
+ "status": "queued",
185
+ "position": position,
186
+ "credits_remaining": user.credits
187
+ }
188
+
189
+
190
+ @router.post("/generate-text")
191
+ async def generate_text(
192
+ request: GenerateTextRequest,
193
+ user: User = Depends(verify_credits),
194
+ db: AsyncSession = Depends(get_db)
195
+ ):
196
+ """
197
+ Queue a text generation job.
198
+ """
199
+ job = await create_job(
200
+ db=db,
201
+ user=user,
202
+ job_type="text",
203
+ input_data={
204
+ "prompt": request.prompt,
205
+ "model": request.model
206
+ }
207
+ )
208
+
209
+ position = await get_queue_position(db, job.job_id)
210
+
211
+ return {
212
+ "success": True,
213
+ "job_id": job.job_id,
214
+ "status": "queued",
215
+ "position": position,
216
+ "credits_remaining": user.credits
217
+ }
218
+
219
+
220
+ @router.post("/analyze-image")
221
+ async def analyze_image(
222
+ request: AnalyzeImageRequest,
223
+ user: User = Depends(verify_credits),
224
+ db: AsyncSession = Depends(get_db)
225
+ ):
226
+ """
227
+ Queue an image analysis job.
228
+ """
229
+ job = await create_job(
230
+ db=db,
231
+ user=user,
232
+ job_type="analyze",
233
+ input_data={
234
+ "base64_image": request.base64_image,
235
+ "mime_type": request.mime_type,
236
+ "prompt": request.prompt
237
+ }
238
+ )
239
+
240
+ position = await get_queue_position(db, job.job_id)
241
+
242
+ return {
243
+ "success": True,
244
+ "job_id": job.job_id,
245
+ "status": "queued",
246
+ "position": position,
247
+ "credits_remaining": user.credits
248
+ }
249
+
250
+
251
+ @router.get("/job/{job_id}")
252
+ async def get_job_status(
253
+ job_id: str,
254
+ user: User = Depends(get_current_user),
255
+ db: AsyncSession = Depends(get_db)
256
+ ):
257
+ """
258
+ Get the status of a job.
259
+ Poll this endpoint until status is 'completed' or 'failed'.
260
+ """
261
+ query = select(GeminiJob).where(
262
+ GeminiJob.job_id == job_id,
263
+ GeminiJob.user_id == user.user_id
264
+ )
265
+ result = await db.execute(query)
266
+ job = result.scalar_one_or_none()
267
+
268
+ if not job:
269
+ raise HTTPException(
270
+ status_code=status.HTTP_404_NOT_FOUND,
271
+ detail="Job not found"
272
+ )
273
+
274
+ response = {
275
+ "success": True,
276
+ "job_id": job.job_id,
277
+ "job_type": job.job_type,
278
+ "status": job.status,
279
+ "created_at": job.created_at.isoformat() if job.created_at else None,
280
+ "credits_remaining": user.credits
281
+ }
282
+
283
+ if job.status == "queued":
284
+ response["position"] = await get_queue_position(db, job.job_id)
285
+
286
+ if job.status == "processing":
287
+ response["started_at"] = job.started_at.isoformat() if job.started_at else None
288
+
289
+ if job.status == "completed":
290
+ response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None
291
+ response["output"] = job.output_data
292
+
293
+ # For video jobs, add download URL
294
+ if job.job_type == "video" and job.output_data and job.output_data.get("filename"):
295
+ response["download_url"] = f"/gemini/download/{job.job_id}"
296
+
297
+ if job.status == "failed":
298
+ response["error"] = job.error_message
299
+ response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None
300
+
301
+ return response
302
+
303
+
304
+ @router.get("/download/{job_id}")
305
+ async def download_video(
306
+ job_id: str,
307
+ user: User = Depends(get_current_user),
308
+ db: AsyncSession = Depends(get_db)
309
+ ):
310
+ """
311
+ Download a generated video.
312
+ """
313
+ query = select(GeminiJob).where(
314
+ GeminiJob.job_id == job_id,
315
+ GeminiJob.user_id == user.user_id,
316
+ GeminiJob.job_type == "video"
317
+ )
318
+ result = await db.execute(query)
319
+ job = result.scalar_one_or_none()
320
+
321
+ if not job:
322
+ raise HTTPException(
323
+ status_code=status.HTTP_404_NOT_FOUND,
324
+ detail="Job not found"
325
+ )
326
+
327
+ if job.status != "completed" or not job.output_data or not job.output_data.get("filename"):
328
+ raise HTTPException(
329
+ status_code=status.HTTP_400_BAD_REQUEST,
330
+ detail="Video not ready for download"
331
+ )
332
+
333
+ filepath = os.path.join(DOWNLOADS_DIR, job.output_data["filename"])
334
+ if not os.path.exists(filepath):
335
+ raise HTTPException(
336
+ status_code=status.HTTP_404_NOT_FOUND,
337
+ detail="Video file not found"
338
+ )
339
+
340
+ return FileResponse(
341
+ path=filepath,
342
+ media_type="video/mp4",
343
+ filename=f"video_{job_id}.mp4"
344
+ )
345
+
346
+
347
+ @router.get("/models")
348
+ async def get_models():
349
+ """
350
+ Get available model names.
351
+ """
352
+ return {"models": MODELS}
services/drive_service.py CHANGED
@@ -24,9 +24,10 @@ class DriveService:
24
  def __init__(self):
25
  self.creds = None
26
  self.service = None
27
- self.client_id = os.getenv('GOOGLE_CLIENT_ID')
28
- self.client_secret = os.getenv('GOOGLE_CLIENT_SECRET')
29
- self.refresh_token = os.getenv('GOOGLE_REFRESH_TOKEN')
 
30
 
31
  def authenticate(self):
32
  """Authenticate using the refresh token."""
 
24
  def __init__(self):
25
  self.creds = None
26
  self.service = None
27
+ # Server-side credentials for Drive API
28
+ self.client_id = os.getenv('SERVER_GOOGLE_CLIENT_ID') or os.getenv('GOOGLE_CLIENT_ID')
29
+ self.client_secret = os.getenv('SERVER_GOOGLE_CLIENT_SECRET') or os.getenv('GOOGLE_CLIENT_SECRET')
30
+ self.refresh_token = os.getenv('SERVER_GOOGLE_REFRESH_TOKEN') or os.getenv('GOOGLE_REFRESH_TOKEN')
31
 
32
  def authenticate(self):
33
  """Authenticate using the refresh token."""
services/email_service.py CHANGED
@@ -32,9 +32,10 @@ class GmailService:
32
  def __init__(self):
33
  self.creds = None
34
  self.service = None
35
- self.client_id = os.getenv('GOOGLE_CLIENT_ID')
36
- self.client_secret = os.getenv('GOOGLE_CLIENT_SECRET')
37
- self.refresh_token = os.getenv('GOOGLE_REFRESH_TOKEN')
 
38
  self.sender_email = os.getenv('SMTP_SENDER') or os.getenv('EMAIL_ID')
39
 
40
  def authenticate(self):
 
32
  def __init__(self):
33
  self.creds = None
34
  self.service = None
35
+ # Server-side credentials for Gmail API
36
+ self.client_id = os.getenv('SERVER_GOOGLE_CLIENT_ID') or os.getenv('GOOGLE_CLIENT_ID')
37
+ self.client_secret = os.getenv('SERVER_GOOGLE_CLIENT_SECRET') or os.getenv('GOOGLE_CLIENT_SECRET')
38
+ self.refresh_token = os.getenv('SERVER_GOOGLE_REFRESH_TOKEN') or os.getenv('GOOGLE_REFRESH_TOKEN')
39
  self.sender_email = os.getenv('SMTP_SENDER') or os.getenv('EMAIL_ID')
40
 
41
  def authenticate(self):
services/gemini_service.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gemini AI Service for image and video generation.
3
+ Python port of the TypeScript geminiService.ts
4
+ Uses server-side API key from environment.
5
+ """
6
+ import asyncio
7
+ import logging
8
+ import os
9
+ import uuid
10
+ import httpx
11
+ from typing import Optional, Literal
12
+ from google import genai
13
+ from google.genai import types
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Model names - easily configurable
18
+ MODELS = {
19
+ "text_generation": "gemini-2.5-flash",
20
+ "image_edit": "gemini-2.5-flash-image",
21
+ "video_generation": "veo-3.1-fast-generate-preview"
22
+ }
23
+
24
+ # Type aliases
25
+ AspectRatio = Literal["16:9", "9:16"]
26
+ Resolution = Literal["720p", "1080p"]
27
+
28
+ # Video downloads directory
29
+ DOWNLOADS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "downloads")
30
+
31
+ # Ensure downloads directory exists
32
+ os.makedirs(DOWNLOADS_DIR, exist_ok=True)
33
+
34
+ # Concurrency limits from environment (defaults)
35
+ MAX_CONCURRENT_VIDEOS = int(os.getenv("MAX_CONCURRENT_VIDEOS", "2"))
36
+ MAX_CONCURRENT_IMAGES = int(os.getenv("MAX_CONCURRENT_IMAGES", "5"))
37
+ MAX_CONCURRENT_TEXT = int(os.getenv("MAX_CONCURRENT_TEXT", "10"))
38
+
39
+ # Semaphores for concurrency control
40
+ _video_semaphore: Optional[asyncio.Semaphore] = None
41
+ _image_semaphore: Optional[asyncio.Semaphore] = None
42
+ _text_semaphore: Optional[asyncio.Semaphore] = None
43
+
44
+
45
+ def get_video_semaphore() -> asyncio.Semaphore:
46
+ """Get or create video semaphore."""
47
+ global _video_semaphore
48
+ if _video_semaphore is None:
49
+ _video_semaphore = asyncio.Semaphore(MAX_CONCURRENT_VIDEOS)
50
+ logger.info(f"Video semaphore initialized with limit: {MAX_CONCURRENT_VIDEOS}")
51
+ return _video_semaphore
52
+
53
+
54
+ def get_image_semaphore() -> asyncio.Semaphore:
55
+ """Get or create image semaphore."""
56
+ global _image_semaphore
57
+ if _image_semaphore is None:
58
+ _image_semaphore = asyncio.Semaphore(MAX_CONCURRENT_IMAGES)
59
+ logger.info(f"Image semaphore initialized with limit: {MAX_CONCURRENT_IMAGES}")
60
+ return _image_semaphore
61
+
62
+
63
+ def get_text_semaphore() -> asyncio.Semaphore:
64
+ """Get or create text semaphore."""
65
+ global _text_semaphore
66
+ if _text_semaphore is None:
67
+ _text_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TEXT)
68
+ logger.info(f"Text semaphore initialized with limit: {MAX_CONCURRENT_TEXT}")
69
+ return _text_semaphore
70
+
71
+
72
+ def get_gemini_api_key() -> str:
73
+ """Get Gemini API key from environment."""
74
+ api_key = os.getenv("GEMINI_API_KEY")
75
+ if not api_key:
76
+ raise ValueError("GEMINI_API_KEY environment variable not set")
77
+ return api_key
78
+
79
+
80
+ class GeminiService:
81
+ """
82
+ Gemini AI Service for text, image, and video generation.
83
+ Uses server-side API key from environment.
84
+ """
85
+
86
+ def __init__(self, api_key: Optional[str] = None):
87
+ """Initialize the Gemini client with API key from env or provided."""
88
+ self.api_key = api_key or get_gemini_api_key()
89
+ self.client = genai.Client(api_key=self.api_key)
90
+
91
+ def _handle_api_error(self, error: Exception, context: str):
92
+ """Handle API errors with descriptive messages."""
93
+ msg = str(error)
94
+ if "404" in msg or "NOT_FOUND" in msg or "Requested entity was not found" in msg or "[5," in msg:
95
+ raise ValueError(
96
+ f"Model not found ({context}). Ensure your API key project has access to this model. "
97
+ "Veo requires a paid account."
98
+ )
99
+ raise error
100
+
101
+ async def generate_animation_prompt(
102
+ self,
103
+ base64_image: str,
104
+ mime_type: str,
105
+ custom_prompt: Optional[str] = None
106
+ ) -> str:
107
+ """
108
+ Analyzes the image to generate a suitable animation prompt.
109
+ """
110
+ default_prompt = custom_prompt or "Describe how this image could be subtly animated with cinematic movement."
111
+ async with get_text_semaphore():
112
+ try:
113
+ response = await asyncio.to_thread(
114
+ self.client.models.generate_content,
115
+ model=MODELS["text_generation"],
116
+ contents=types.Content(
117
+ parts=[
118
+ types.Part.from_bytes(
119
+ data=base64_image,
120
+ mime_type=mime_type
121
+ ),
122
+ types.Part.from_text(text=default_prompt)
123
+ ]
124
+ )
125
+ )
126
+ return response.text or "Cinematic subtle movement"
127
+ except Exception as error:
128
+ self._handle_api_error(error, MODELS["text_generation"])
129
+
130
+ async def edit_image(
131
+ self,
132
+ base64_image: str,
133
+ mime_type: str,
134
+ prompt: str
135
+ ) -> str:
136
+ """
137
+ Edit an image using Gemini image model.
138
+ Returns base64 data URI of the edited image.
139
+ """
140
+ async with get_image_semaphore():
141
+ try:
142
+ response = await asyncio.to_thread(
143
+ self.client.models.generate_content,
144
+ model=MODELS["image_edit"],
145
+ contents=types.Content(
146
+ parts=[
147
+ types.Part.from_bytes(
148
+ data=base64_image,
149
+ mime_type=mime_type
150
+ ),
151
+ types.Part.from_text(text=prompt or "Enhance this image")
152
+ ]
153
+ )
154
+ )
155
+
156
+ candidates = response.candidates
157
+ if not candidates:
158
+ raise ValueError("No candidates returned from Gemini.")
159
+
160
+ for part in candidates[0].content.parts:
161
+ if hasattr(part, 'inline_data') and part.inline_data and part.inline_data.data:
162
+ result_mime = part.inline_data.mime_type or 'image/png'
163
+ return f"data:{result_mime};base64,{part.inline_data.data}"
164
+
165
+ raise ValueError("No image data found in the response.")
166
+ except Exception as error:
167
+ self._handle_api_error(error, MODELS["image_edit"])
168
+
169
+ async def start_video_generation(
170
+ self,
171
+ base64_image: str,
172
+ mime_type: str,
173
+ prompt: str,
174
+ aspect_ratio: AspectRatio = "16:9",
175
+ resolution: Resolution = "720p",
176
+ number_of_videos: int = 1
177
+ ) -> dict:
178
+ """
179
+ Start video generation using Veo model.
180
+ Returns operation details for polling.
181
+ """
182
+ async with get_video_semaphore():
183
+ try:
184
+ # Start video generation
185
+ operation = await asyncio.to_thread(
186
+ self.client.models.generate_videos,
187
+ model=MODELS["video_generation"],
188
+ prompt=prompt,
189
+ image=types.Image(
190
+ image_bytes=base64_image,
191
+ mime_type=mime_type
192
+ ),
193
+ config=types.GenerateVideosConfig(
194
+ number_of_videos=number_of_videos,
195
+ resolution=resolution,
196
+ aspect_ratio=aspect_ratio
197
+ )
198
+ )
199
+
200
+ # Return operation details
201
+ return {
202
+ "gemini_operation_name": operation.name,
203
+ "done": operation.done,
204
+ "status": "completed" if operation.done else "pending"
205
+ }
206
+ except Exception as error:
207
+ self._handle_api_error(error, MODELS["video_generation"])
208
+
209
+ async def check_video_status(self, gemini_operation_name: str) -> dict:
210
+ """
211
+ Check the status of a video generation operation.
212
+ Returns status and video URL if complete.
213
+ """
214
+ try:
215
+ # Get operation status
216
+ operation = await asyncio.to_thread(
217
+ self.client.operations.get,
218
+ name=gemini_operation_name
219
+ )
220
+
221
+ if not operation.done:
222
+ return {
223
+ "gemini_operation_name": gemini_operation_name,
224
+ "done": False,
225
+ "status": "pending"
226
+ }
227
+
228
+ if operation.error:
229
+ return {
230
+ "gemini_operation_name": gemini_operation_name,
231
+ "done": True,
232
+ "status": "failed",
233
+ "error": operation.error.message or "Unknown error"
234
+ }
235
+
236
+ # Extract video URI
237
+ generated_videos = getattr(operation.response, 'generated_videos', None)
238
+ if not generated_videos:
239
+ return {
240
+ "gemini_operation_name": gemini_operation_name,
241
+ "done": True,
242
+ "status": "failed",
243
+ "error": "No video URI returned. May be due to safety filters."
244
+ }
245
+
246
+ video_uri = generated_videos[0].video.uri if generated_videos[0].video else None
247
+ if not video_uri:
248
+ return {
249
+ "gemini_operation_name": gemini_operation_name,
250
+ "done": True,
251
+ "status": "failed",
252
+ "error": "No video URI in response."
253
+ }
254
+
255
+ # Return success with video URL (internal - will be downloaded by router)
256
+ return {
257
+ "gemini_operation_name": gemini_operation_name,
258
+ "done": True,
259
+ "status": "completed",
260
+ "video_url": f"{video_uri}&key={self.api_key}"
261
+ }
262
+
263
+ except Exception as error:
264
+ msg = str(error)
265
+ if "404" in msg or "NOT_FOUND" in msg or "Requested entity was not found" in msg:
266
+ return {
267
+ "gemini_operation_name": gemini_operation_name,
268
+ "done": True,
269
+ "status": "failed",
270
+ "error": "Operation not found (404). It may have expired."
271
+ }
272
+ raise error
273
+
274
+ async def download_video(self, video_url: str, operation_id: str) -> str:
275
+ """
276
+ Download video from Gemini to local storage.
277
+ Returns the local filename.
278
+ """
279
+ filename = f"{operation_id}.mp4"
280
+ filepath = os.path.join(DOWNLOADS_DIR, filename)
281
+
282
+ try:
283
+ async with httpx.AsyncClient(timeout=120.0) as client:
284
+ response = await client.get(video_url)
285
+ response.raise_for_status()
286
+
287
+ with open(filepath, 'wb') as f:
288
+ f.write(response.content)
289
+
290
+ logger.info(f"Downloaded video to {filepath}")
291
+ return filename
292
+ except Exception as e:
293
+ logger.error(f"Failed to download video: {e}")
294
+ raise ValueError(f"Failed to download video: {e}")
295
+
296
+ async def generate_text(
297
+ self,
298
+ prompt: str,
299
+ model: Optional[str] = None
300
+ ) -> str:
301
+ """
302
+ Simple text generation with Gemini.
303
+ """
304
+ model_name = model or MODELS["text_generation"]
305
+ async with get_text_semaphore():
306
+ try:
307
+ response = await asyncio.to_thread(
308
+ self.client.models.generate_content,
309
+ model=model_name,
310
+ contents=types.Content(
311
+ parts=[types.Part.from_text(text=prompt)]
312
+ )
313
+ )
314
+ return response.text or ""
315
+ except Exception as error:
316
+ self._handle_api_error(error, model_name)
317
+
318
+ async def analyze_image(
319
+ self,
320
+ base64_image: str,
321
+ mime_type: str,
322
+ prompt: str
323
+ ) -> str:
324
+ """
325
+ Analyze image with custom prompt.
326
+ """
327
+ async with get_text_semaphore():
328
+ try:
329
+ response = await asyncio.to_thread(
330
+ self.client.models.generate_content,
331
+ model=MODELS["text_generation"],
332
+ contents=types.Content(
333
+ parts=[
334
+ types.Part.from_bytes(
335
+ data=base64_image,
336
+ mime_type=mime_type
337
+ ),
338
+ types.Part.from_text(text=prompt)
339
+ ]
340
+ )
341
+ )
342
+ return response.text or ""
343
+ except Exception as error:
344
+ self._handle_api_error(error, MODELS["text_generation"])
services/google_auth_service.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modular Google OAuth Service
3
+
4
+ A self-contained, plug-and-play service for verifying Google ID tokens.
5
+ Can be used in any Python application with minimal configuration.
6
+
7
+ Usage:
8
+ from services.google_auth_service import GoogleAuthService, GoogleUserInfo
9
+
10
+ # Initialize with client ID
11
+ auth_service = GoogleAuthService(client_id="your-google-client-id")
12
+
13
+ # Or use environment variable GOOGLE_CLIENT_ID
14
+ auth_service = GoogleAuthService()
15
+
16
+ # Verify a Google ID token
17
+ user_info = auth_service.verify_token(id_token)
18
+ print(user_info.email, user_info.google_id, user_info.name)
19
+
20
+ Environment Variables:
21
+ GOOGLE_CLIENT_ID: Your Google OAuth 2.0 Client ID
22
+
23
+ Dependencies:
24
+ google-auth>=2.0.0
25
+ google-auth-oauthlib>=1.0.0
26
+ """
27
+
28
+ import os
29
+ import logging
30
+ from dataclasses import dataclass
31
+ from typing import Optional
32
+ from google.oauth2 import id_token as google_id_token
33
+ from google.auth.transport import requests as google_requests
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ @dataclass
39
+ class GoogleUserInfo:
40
+ """
41
+ User information extracted from a verified Google ID token.
42
+
43
+ Attributes:
44
+ google_id: Unique Google user identifier (sub claim)
45
+ email: User's email address
46
+ email_verified: Whether Google has verified the email
47
+ name: User's display name (may be None)
48
+ picture: URL to user's profile picture (may be None)
49
+ given_name: User's first name (may be None)
50
+ family_name: User's last name (may be None)
51
+ locale: User's locale preference (may be None)
52
+ """
53
+ google_id: str
54
+ email: str
55
+ email_verified: bool = True
56
+ name: Optional[str] = None
57
+ picture: Optional[str] = None
58
+ given_name: Optional[str] = None
59
+ family_name: Optional[str] = None
60
+ locale: Optional[str] = None
61
+
62
+
63
+ class GoogleAuthError(Exception):
64
+ """Base exception for Google Auth errors."""
65
+ pass
66
+
67
+
68
+ class InvalidTokenError(GoogleAuthError):
69
+ """Raised when the token is invalid or expired."""
70
+ pass
71
+
72
+
73
+ class ConfigurationError(GoogleAuthError):
74
+ """Raised when the service is not properly configured."""
75
+ pass
76
+
77
+
78
+ class GoogleAuthService:
79
+ """
80
+ Service for verifying Google OAuth ID tokens.
81
+
82
+ This service validates ID tokens issued by Google Sign-In and extracts
83
+ user information. It's designed to be modular and reusable across
84
+ different applications.
85
+
86
+ Example:
87
+ service = GoogleAuthService()
88
+ try:
89
+ user_info = service.verify_token(token_from_frontend)
90
+ print(f"Welcome {user_info.name}!")
91
+ except InvalidTokenError:
92
+ print("Invalid or expired token")
93
+ """
94
+
95
+ def __init__(
96
+ self,
97
+ client_id: Optional[str] = None,
98
+ clock_skew_seconds: int = 0
99
+ ):
100
+ """
101
+ Initialize the Google Auth Service.
102
+
103
+ Args:
104
+ client_id: Google OAuth 2.0 Client ID. If not provided,
105
+ falls back to GOOGLE_CLIENT_ID environment variable.
106
+ clock_skew_seconds: Allowed clock skew in seconds for token
107
+ validation (default: 0).
108
+
109
+ Raises:
110
+ ConfigurationError: If no client_id is provided or found.
111
+ """
112
+ self.client_id = client_id or os.getenv("AUTH_SIGN_IN_GOOGLE_CLIENT_ID")
113
+ self.clock_skew_seconds = clock_skew_seconds
114
+
115
+ if not self.client_id:
116
+ raise ConfigurationError(
117
+ "Google Client ID is required. Either pass client_id parameter "
118
+ "or set GOOGLE_CLIENT_ID environment variable."
119
+ )
120
+
121
+ logger.info(f"GoogleAuthService initialized with client_id: {self.client_id[:20]}...")
122
+
123
+ def verify_token(self, id_token: str) -> GoogleUserInfo:
124
+ """
125
+ Verify a Google ID token and extract user information.
126
+
127
+ Args:
128
+ id_token: The ID token received from the frontend after
129
+ Google Sign-In.
130
+
131
+ Returns:
132
+ GoogleUserInfo: Dataclass containing user's Google profile info.
133
+
134
+ Raises:
135
+ InvalidTokenError: If the token is invalid, expired, or
136
+ doesn't match the expected client ID.
137
+ """
138
+ if not id_token:
139
+ raise InvalidTokenError("Token cannot be empty")
140
+
141
+ try:
142
+ # Verify the token with Google
143
+ idinfo = google_id_token.verify_oauth2_token(
144
+ id_token,
145
+ google_requests.Request(),
146
+ self.client_id,
147
+ clock_skew_in_seconds=self.clock_skew_seconds
148
+ )
149
+
150
+ # Validate issuer
151
+ if idinfo.get("iss") not in ["accounts.google.com", "https://accounts.google.com"]:
152
+ raise InvalidTokenError("Invalid token issuer")
153
+
154
+ # Validate audience
155
+ if idinfo.get("aud") != self.client_id:
156
+ raise InvalidTokenError("Token was not issued for this application")
157
+
158
+ # Extract user info
159
+ return GoogleUserInfo(
160
+ google_id=idinfo["sub"],
161
+ email=idinfo["email"],
162
+ email_verified=idinfo.get("email_verified", False),
163
+ name=idinfo.get("name"),
164
+ picture=idinfo.get("picture"),
165
+ given_name=idinfo.get("given_name"),
166
+ family_name=idinfo.get("family_name"),
167
+ locale=idinfo.get("locale")
168
+ )
169
+
170
+ except ValueError as e:
171
+ logger.warning(f"Token verification failed: {e}")
172
+ raise InvalidTokenError(f"Token verification failed: {str(e)}")
173
+ except Exception as e:
174
+ logger.error(f"Unexpected error during token verification: {e}")
175
+ raise InvalidTokenError(f"Token verification error: {str(e)}")
176
+
177
+ def verify_token_safe(self, id_token: str) -> Optional[GoogleUserInfo]:
178
+ """
179
+ Verify a Google ID token without raising exceptions.
180
+
181
+ Useful for cases where you want to check validity without
182
+ exception handling.
183
+
184
+ Args:
185
+ id_token: The ID token to verify.
186
+
187
+ Returns:
188
+ GoogleUserInfo if valid, None if invalid.
189
+ """
190
+ try:
191
+ return self.verify_token(id_token)
192
+ except GoogleAuthError:
193
+ return None
194
+
195
+
196
+ # Singleton instance for convenience (initialized on first use)
197
+ _default_service: Optional[GoogleAuthService] = None
198
+
199
+
200
+ def get_google_auth_service() -> GoogleAuthService:
201
+ """
202
+ Get the default GoogleAuthService instance.
203
+
204
+ Creates a singleton instance using environment variables.
205
+
206
+ Returns:
207
+ GoogleAuthService: The default service instance.
208
+
209
+ Raises:
210
+ ConfigurationError: If GOOGLE_CLIENT_ID is not set.
211
+ """
212
+ global _default_service
213
+ if _default_service is None:
214
+ _default_service = GoogleAuthService()
215
+ return _default_service
216
+
217
+
218
+ def verify_google_token(id_token: str) -> GoogleUserInfo:
219
+ """
220
+ Convenience function to verify a token using the default service.
221
+
222
+ Args:
223
+ id_token: The Google ID token to verify.
224
+
225
+ Returns:
226
+ GoogleUserInfo: Verified user information.
227
+
228
+ Raises:
229
+ InvalidTokenError: If verification fails.
230
+ ConfigurationError: If service is not configured.
231
+ """
232
+ return get_google_auth_service().verify_token(id_token)
services/job_worker.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Background worker for processing Gemini jobs.
3
+ Runs continuously, picks up queued jobs, processes them.
4
+ """
5
+ import asyncio
6
+ import logging
7
+ import os
8
+ from datetime import datetime
9
+ from sqlalchemy import select
10
+ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
11
+
12
+ from core.database import DATABASE_URL
13
+ from core.models import GeminiJob
14
+ from services.gemini_service import GeminiService, DOWNLOADS_DIR
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Worker configuration
19
+ WORKER_POLL_INTERVAL = int(os.getenv("WORKER_POLL_INTERVAL", "5")) # seconds
20
+ MAX_CONCURRENT_VIDEO_JOBS = int(os.getenv("MAX_CONCURRENT_VIDEO_JOBS", "2"))
21
+ MAX_CONCURRENT_IMAGE_JOBS = int(os.getenv("MAX_CONCURRENT_IMAGE_JOBS", "3"))
22
+ MAX_CONCURRENT_TEXT_JOBS = int(os.getenv("MAX_CONCURRENT_TEXT_JOBS", "5"))
23
+
24
+ # Track running jobs
25
+ _running_jobs = {
26
+ "video": 0,
27
+ "image": 0,
28
+ "text": 0,
29
+ "analyze": 0,
30
+ "animation_prompt": 0
31
+ }
32
+ _lock = asyncio.Lock()
33
+
34
+
35
+ class JobWorker:
36
+ """Background worker for processing Gemini jobs."""
37
+
38
+ def __init__(self):
39
+ self.engine = create_async_engine(DATABASE_URL, echo=False)
40
+ self.async_session = async_sessionmaker(
41
+ self.engine,
42
+ class_=AsyncSession,
43
+ expire_on_commit=False
44
+ )
45
+ self._running = False
46
+ self._tasks = []
47
+
48
+ async def start(self):
49
+ """Start the worker."""
50
+ self._running = True
51
+ logger.info("Job worker started")
52
+ asyncio.create_task(self._poll_loop())
53
+
54
+ async def stop(self):
55
+ """Stop the worker."""
56
+ self._running = False
57
+ # Wait for running tasks to complete
58
+ if self._tasks:
59
+ await asyncio.gather(*self._tasks, return_exceptions=True)
60
+ logger.info("Job worker stopped")
61
+
62
+ async def _poll_loop(self):
63
+ """Main polling loop."""
64
+ while self._running:
65
+ try:
66
+ await self._process_queued_jobs()
67
+ except Exception as e:
68
+ logger.error(f"Error in worker poll loop: {e}")
69
+ await asyncio.sleep(WORKER_POLL_INTERVAL)
70
+
71
+ async def _process_queued_jobs(self):
72
+ """Find and process queued jobs."""
73
+ async with self.async_session() as session:
74
+ # Get queued jobs
75
+ query = select(GeminiJob).where(
76
+ GeminiJob.status == "queued"
77
+ ).order_by(GeminiJob.created_at)
78
+
79
+ result = await session.execute(query)
80
+ jobs = result.scalars().all()
81
+
82
+ for job in jobs:
83
+ # Check if we can process this job type
84
+ if not await self._can_process(job.job_type):
85
+ continue
86
+
87
+ # Mark as processing
88
+ async with _lock:
89
+ _running_jobs[job.job_type] = _running_jobs.get(job.job_type, 0) + 1
90
+
91
+ try:
92
+ job.status = "processing"
93
+ job.started_at = datetime.utcnow()
94
+ await session.commit()
95
+
96
+ # Process job in background
97
+ task = asyncio.create_task(self._process_job(job.job_id))
98
+ self._tasks.append(task)
99
+ task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
100
+ except Exception as e:
101
+ logger.error(f"Error starting job {job.job_id}: {e}")
102
+ async with _lock:
103
+ _running_jobs[job.job_type] = max(0, _running_jobs.get(job.job_type, 0) - 1)
104
+
105
+ async def _can_process(self, job_type: str) -> bool:
106
+ """Check if we can process another job of this type."""
107
+ async with _lock:
108
+ current = _running_jobs.get(job_type, 0)
109
+ if job_type == "video":
110
+ return current < MAX_CONCURRENT_VIDEO_JOBS
111
+ elif job_type in ("image", "edit_image"):
112
+ return current < MAX_CONCURRENT_IMAGE_JOBS
113
+ else: # text, analyze, animation_prompt
114
+ return current < MAX_CONCURRENT_TEXT_JOBS
115
+
116
+ async def _process_job(self, job_id: str):
117
+ """Process a single job."""
118
+ async with self.async_session() as session:
119
+ try:
120
+ # Get the job
121
+ query = select(GeminiJob).where(GeminiJob.job_id == job_id)
122
+ result = await session.execute(query)
123
+ job = result.scalar_one_or_none()
124
+
125
+ if not job:
126
+ logger.error(f"Job {job_id} not found")
127
+ return
128
+
129
+ logger.info(f"Processing job {job_id} (type: {job.job_type})")
130
+
131
+ service = GeminiService()
132
+ input_data = job.input_data or {}
133
+
134
+ if job.job_type == "video":
135
+ await self._process_video_job(session, job, service, input_data)
136
+ elif job.job_type == "image":
137
+ await self._process_image_job(session, job, service, input_data)
138
+ elif job.job_type == "text":
139
+ await self._process_text_job(session, job, service, input_data)
140
+ elif job.job_type == "analyze":
141
+ await self._process_analyze_job(session, job, service, input_data)
142
+ elif job.job_type == "animation_prompt":
143
+ await self._process_animation_prompt_job(session, job, service, input_data)
144
+ else:
145
+ job.status = "failed"
146
+ job.error_message = f"Unknown job type: {job.job_type}"
147
+ job.completed_at = datetime.utcnow()
148
+ await session.commit()
149
+
150
+ except Exception as e:
151
+ logger.error(f"Error processing job {job_id}: {e}")
152
+ try:
153
+ job.status = "failed"
154
+ job.error_message = str(e)
155
+ job.completed_at = datetime.utcnow()
156
+ await session.commit()
157
+ except:
158
+ pass
159
+ finally:
160
+ async with _lock:
161
+ job_type = job.job_type if job else "unknown"
162
+ _running_jobs[job_type] = max(0, _running_jobs.get(job_type, 0) - 1)
163
+
164
+ async def _process_video_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
165
+ """Process a video generation job."""
166
+ # Start video generation
167
+ result = await service.start_video_generation(
168
+ base64_image=input_data.get("base64_image", ""),
169
+ mime_type=input_data.get("mime_type", "image/jpeg"),
170
+ prompt=input_data.get("prompt", ""),
171
+ aspect_ratio=input_data.get("aspect_ratio", "16:9"),
172
+ resolution=input_data.get("resolution", "720p"),
173
+ number_of_videos=input_data.get("number_of_videos", 1)
174
+ )
175
+
176
+ # Save third party ID
177
+ job.third_party_id = result.get("gemini_operation_name")
178
+ await session.commit()
179
+
180
+ # Poll until done
181
+ while True:
182
+ status_result = await service.check_video_status(job.third_party_id)
183
+
184
+ if status_result.get("done"):
185
+ if status_result.get("status") == "completed":
186
+ # Download video
187
+ video_url = status_result.get("video_url")
188
+ if video_url:
189
+ filename = await service.download_video(video_url, job.job_id)
190
+ job.status = "completed"
191
+ job.output_data = {"filename": filename}
192
+ else:
193
+ job.status = "failed"
194
+ job.error_message = "No video URL returned"
195
+ else:
196
+ job.status = "failed"
197
+ job.error_message = status_result.get("error", "Unknown error")
198
+
199
+ job.completed_at = datetime.utcnow()
200
+ await session.commit()
201
+ break
202
+
203
+ await asyncio.sleep(10) # Poll every 10 seconds
204
+
205
+ async def _process_image_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
206
+ """Process an image edit job."""
207
+ result = await service.edit_image(
208
+ base64_image=input_data.get("base64_image", ""),
209
+ mime_type=input_data.get("mime_type", "image/jpeg"),
210
+ prompt=input_data.get("prompt", "")
211
+ )
212
+
213
+ job.status = "completed"
214
+ job.output_data = {"image": result}
215
+ job.completed_at = datetime.utcnow()
216
+ await session.commit()
217
+
218
+ async def _process_text_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
219
+ """Process a text generation job."""
220
+ result = await service.generate_text(
221
+ prompt=input_data.get("prompt", ""),
222
+ model=input_data.get("model")
223
+ )
224
+
225
+ job.status = "completed"
226
+ job.output_data = {"text": result}
227
+ job.completed_at = datetime.utcnow()
228
+ await session.commit()
229
+
230
+ async def _process_analyze_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
231
+ """Process an image analysis job."""
232
+ result = await service.analyze_image(
233
+ base64_image=input_data.get("base64_image", ""),
234
+ mime_type=input_data.get("mime_type", "image/jpeg"),
235
+ prompt=input_data.get("prompt", "")
236
+ )
237
+
238
+ job.status = "completed"
239
+ job.output_data = {"analysis": result}
240
+ job.completed_at = datetime.utcnow()
241
+ await session.commit()
242
+
243
+ async def _process_animation_prompt_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
244
+ """Process an animation prompt generation job."""
245
+ result = await service.generate_animation_prompt(
246
+ base64_image=input_data.get("base64_image", ""),
247
+ mime_type=input_data.get("mime_type", "image/jpeg"),
248
+ custom_prompt=input_data.get("custom_prompt")
249
+ )
250
+
251
+ job.status = "completed"
252
+ job.output_data = {"prompt": result}
253
+ job.completed_at = datetime.utcnow()
254
+ await session.commit()
255
+
256
+
257
+ # Singleton worker instance
258
+ _worker: JobWorker = None
259
+
260
+
261
+ def get_worker() -> JobWorker:
262
+ """Get the global worker instance."""
263
+ global _worker
264
+ if _worker is None:
265
+ _worker = JobWorker()
266
+ return _worker
267
+
268
+
269
+ async def start_worker():
270
+ """Start the background worker."""
271
+ worker = get_worker()
272
+ await worker.start()
273
+
274
+
275
+ async def stop_worker():
276
+ """Stop the background worker."""
277
+ worker = get_worker()
278
+ await worker.stop()
services/jwt_service.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modular JWT Service
3
+
4
+ A self-contained, plug-and-play service for creating and verifying JWT tokens.
5
+ Can be used in any Python application with minimal configuration.
6
+
7
+ Usage:
8
+ from services.jwt_service import JWTService, TokenPayload
9
+
10
+ # Initialize with secret key
11
+ jwt_service = JWTService(secret_key="your-secret-key")
12
+
13
+ # Or use environment variable JWT_SECRET
14
+ jwt_service = JWTService()
15
+
16
+ # Create a token
17
+ token = jwt_service.create_token(user_id="user123", email="[email protected]")
18
+
19
+ # Verify a token
20
+ payload = jwt_service.verify_token(token)
21
+ print(payload.user_id, payload.email)
22
+
23
+ Environment Variables:
24
+ JWT_SECRET: Your secret key for signing tokens (required)
25
+ JWT_EXPIRY_HOURS: Token expiry in hours (default: 168 = 7 days)
26
+ JWT_ALGORITHM: Algorithm to use (default: HS256)
27
+
28
+ Dependencies:
29
+ PyJWT>=2.8.0
30
+
31
+ Generate a secure secret:
32
+ python -c "import secrets; print(secrets.token_urlsafe(64))"
33
+ """
34
+
35
+ import os
36
+ import logging
37
+ from dataclasses import dataclass
38
+ from datetime import datetime, timedelta
39
+ from typing import Optional, Dict, Any
40
+ import jwt
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ @dataclass
46
+ class TokenPayload:
47
+ """
48
+ Payload extracted from a verified JWT token.
49
+
50
+ Attributes:
51
+ user_id: The user's unique identifier (sub claim)
52
+ email: The user's email address
53
+ issued_at: When the token was issued
54
+ expires_at: When the token expires
55
+ extra: Any additional claims in the token
56
+ """
57
+ user_id: str
58
+ email: str
59
+ issued_at: datetime
60
+ expires_at: datetime
61
+ extra: Dict[str, Any] = None
62
+
63
+ def __post_init__(self):
64
+ if self.extra is None:
65
+ self.extra = {}
66
+
67
+ @property
68
+ def is_expired(self) -> bool:
69
+ """Check if the token has expired."""
70
+ return datetime.utcnow() > self.expires_at
71
+
72
+ @property
73
+ def time_until_expiry(self) -> timedelta:
74
+ """Get time remaining until expiry."""
75
+ return self.expires_at - datetime.utcnow()
76
+
77
+
78
+ class JWTError(Exception):
79
+ """Base exception for JWT errors."""
80
+ pass
81
+
82
+
83
+ class TokenExpiredError(JWTError):
84
+ """Raised when the token has expired."""
85
+ pass
86
+
87
+
88
+ class InvalidTokenError(JWTError):
89
+ """Raised when the token is invalid."""
90
+ pass
91
+
92
+
93
+ class ConfigurationError(JWTError):
94
+ """Raised when the service is not properly configured."""
95
+ pass
96
+
97
+
98
+ class JWTService:
99
+ """
100
+ Service for creating and verifying JWT tokens.
101
+
102
+ This service handles JWT token lifecycle for authentication.
103
+ It's designed to be modular and reusable across different applications.
104
+
105
+ Example:
106
+ service = JWTService(secret_key="my-secret")
107
+
108
+ # Create token
109
+ token = service.create_token(user_id="u123", email="[email protected]")
110
+
111
+ # Verify token
112
+ try:
113
+ payload = service.verify_token(token)
114
+ print(f"User: {payload.user_id}")
115
+ except TokenExpiredError:
116
+ print("Token expired, please login again")
117
+ except InvalidTokenError:
118
+ print("Invalid token")
119
+ """
120
+
121
+ # Default configuration
122
+ DEFAULT_ALGORITHM = "HS256"
123
+ DEFAULT_EXPIRY_HOURS = 168 # 7 days
124
+
125
+ def __init__(
126
+ self,
127
+ secret_key: Optional[str] = None,
128
+ algorithm: Optional[str] = None,
129
+ expiry_hours: Optional[int] = None
130
+ ):
131
+ """
132
+ Initialize the JWT Service.
133
+
134
+ Args:
135
+ secret_key: Secret key for signing tokens. If not provided,
136
+ falls back to JWT_SECRET environment variable.
137
+ algorithm: JWT algorithm (default: HS256).
138
+ expiry_hours: Token expiry in hours (default: 168 = 7 days).
139
+
140
+ Raises:
141
+ ConfigurationError: If no secret_key is provided or found.
142
+ """
143
+ self.secret_key = secret_key or os.getenv("JWT_SECRET")
144
+ self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
145
+ self.expiry_hours = expiry_hours or int(
146
+ os.getenv("JWT_EXPIRY_HOURS", str(self.DEFAULT_EXPIRY_HOURS))
147
+ )
148
+
149
+ if not self.secret_key:
150
+ raise ConfigurationError(
151
+ "JWT secret key is required. Either pass secret_key parameter "
152
+ "or set JWT_SECRET environment variable. "
153
+ "Generate one with: python -c \"import secrets; print(secrets.token_urlsafe(64))\""
154
+ )
155
+
156
+ # Warn if secret is too short
157
+ if len(self.secret_key) < 32:
158
+ logger.warning(
159
+ "JWT secret key is short (< 32 chars). "
160
+ "Consider using a longer secret for better security."
161
+ )
162
+
163
+ logger.info(
164
+ f"JWTService initialized (algorithm={self.algorithm}, "
165
+ f"expiry={self.expiry_hours}h)"
166
+ )
167
+
168
+ def create_token(
169
+ self,
170
+ user_id: str,
171
+ email: str,
172
+ extra_claims: Optional[Dict[str, Any]] = None,
173
+ expiry_hours: Optional[int] = None
174
+ ) -> str:
175
+ """
176
+ Create a JWT token for a user.
177
+
178
+ Args:
179
+ user_id: The user's unique identifier.
180
+ email: The user's email address.
181
+ extra_claims: Additional claims to include in the token.
182
+ expiry_hours: Custom expiry for this token (overrides default).
183
+
184
+ Returns:
185
+ str: The encoded JWT token.
186
+ """
187
+ now = datetime.utcnow()
188
+ expiry = expiry_hours or self.expiry_hours
189
+
190
+ payload = {
191
+ "sub": user_id,
192
+ "email": email,
193
+ "iat": now,
194
+ "exp": now + timedelta(hours=expiry),
195
+ }
196
+
197
+ if extra_claims:
198
+ payload.update(extra_claims)
199
+
200
+ token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
201
+
202
+ logger.debug(f"Created token for user_id={user_id}")
203
+ return token
204
+
205
+ def verify_token(self, token: str) -> TokenPayload:
206
+ """
207
+ Verify a JWT token and extract the payload.
208
+
209
+ Args:
210
+ token: The JWT token to verify.
211
+
212
+ Returns:
213
+ TokenPayload: Dataclass containing the verified payload.
214
+
215
+ Raises:
216
+ TokenExpiredError: If the token has expired.
217
+ InvalidTokenError: If the token is invalid or malformed.
218
+ """
219
+ if not token:
220
+ raise InvalidTokenError("Token cannot be empty")
221
+
222
+ try:
223
+ payload = jwt.decode(
224
+ token,
225
+ self.secret_key,
226
+ algorithms=[self.algorithm]
227
+ )
228
+
229
+ # Extract standard claims
230
+ user_id = payload.get("sub")
231
+ email = payload.get("email")
232
+ iat = payload.get("iat")
233
+ exp = payload.get("exp")
234
+
235
+ if not user_id or not email:
236
+ raise InvalidTokenError("Token missing required claims (sub, email)")
237
+
238
+ # Convert timestamps to datetime
239
+ issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
240
+ expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
241
+
242
+ # Extract extra claims
243
+ standard_claims = {"sub", "email", "iat", "exp"}
244
+ extra = {k: v for k, v in payload.items() if k not in standard_claims}
245
+
246
+ return TokenPayload(
247
+ user_id=user_id,
248
+ email=email,
249
+ issued_at=issued_at,
250
+ expires_at=expires_at,
251
+ extra=extra
252
+ )
253
+
254
+ except jwt.ExpiredSignatureError:
255
+ logger.debug("Token verification failed: expired")
256
+ raise TokenExpiredError("Token has expired")
257
+ except jwt.InvalidTokenError as e:
258
+ logger.debug(f"Token verification failed: {e}")
259
+ raise InvalidTokenError(f"Invalid token: {str(e)}")
260
+ except Exception as e:
261
+ logger.error(f"Unexpected error during token verification: {e}")
262
+ raise InvalidTokenError(f"Token verification error: {str(e)}")
263
+
264
+ def verify_token_safe(self, token: str) -> Optional[TokenPayload]:
265
+ """
266
+ Verify a JWT token without raising exceptions.
267
+
268
+ Args:
269
+ token: The JWT token to verify.
270
+
271
+ Returns:
272
+ TokenPayload if valid, None if invalid or expired.
273
+ """
274
+ try:
275
+ return self.verify_token(token)
276
+ except JWTError:
277
+ return None
278
+
279
+ def refresh_token(
280
+ self,
281
+ token: str,
282
+ expiry_hours: Optional[int] = None
283
+ ) -> str:
284
+ """
285
+ Refresh a token by creating a new one with the same claims.
286
+
287
+ Args:
288
+ token: The current (possibly expired) token.
289
+ expiry_hours: Custom expiry for the new token.
290
+
291
+ Returns:
292
+ str: A new JWT token with updated expiry.
293
+
294
+ Raises:
295
+ InvalidTokenError: If the token is malformed.
296
+ """
297
+ try:
298
+ # Decode without verifying expiry
299
+ payload = jwt.decode(
300
+ token,
301
+ self.secret_key,
302
+ algorithms=[self.algorithm],
303
+ options={"verify_exp": False}
304
+ )
305
+
306
+ user_id = payload.get("sub")
307
+ email = payload.get("email")
308
+
309
+ if not user_id or not email:
310
+ raise InvalidTokenError("Token missing required claims")
311
+
312
+ # Preserve extra claims
313
+ standard_claims = {"sub", "email", "iat", "exp"}
314
+ extra = {k: v for k, v in payload.items() if k not in standard_claims}
315
+
316
+ return self.create_token(
317
+ user_id=user_id,
318
+ email=email,
319
+ extra_claims=extra,
320
+ expiry_hours=expiry_hours
321
+ )
322
+
323
+ except jwt.InvalidTokenError as e:
324
+ raise InvalidTokenError(f"Cannot refresh invalid token: {str(e)}")
325
+
326
+
327
+ # Singleton instance for convenience
328
+ _default_service: Optional[JWTService] = None
329
+
330
+
331
+ def get_jwt_service() -> JWTService:
332
+ """
333
+ Get the default JWTService instance.
334
+
335
+ Creates a singleton instance using environment variables.
336
+
337
+ Returns:
338
+ JWTService: The default service instance.
339
+
340
+ Raises:
341
+ ConfigurationError: If JWT_SECRET is not set.
342
+ """
343
+ global _default_service
344
+ if _default_service is None:
345
+ _default_service = JWTService()
346
+ return _default_service
347
+
348
+
349
+ def create_access_token(user_id: str, email: str, **kwargs) -> str:
350
+ """
351
+ Convenience function to create a token using the default service.
352
+
353
+ Args:
354
+ user_id: The user's unique identifier.
355
+ email: The user's email address.
356
+ **kwargs: Additional arguments passed to create_token.
357
+
358
+ Returns:
359
+ str: The encoded JWT token.
360
+ """
361
+ return get_jwt_service().create_token(user_id, email, **kwargs)
362
+
363
+
364
+ def verify_access_token(token: str) -> TokenPayload:
365
+ """
366
+ Convenience function to verify a token using the default service.
367
+
368
+ Args:
369
+ token: The JWT token to verify.
370
+
371
+ Returns:
372
+ TokenPayload: Verified token payload.
373
+
374
+ Raises:
375
+ TokenExpiredError: If the token has expired.
376
+ InvalidTokenError: If the token is invalid.
377
+ """
378
+ return get_jwt_service().verify_token(token)
tests/conftest.py CHANGED
@@ -1,14 +1,27 @@
1
  import pytest
2
  import os
3
  import sys
 
4
  from fastapi.testclient import TestClient
5
  from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
6
 
 
 
 
 
 
7
  # Add parent directory to path to allow importing app
8
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
9
 
10
- from app import app
11
- from core.database import get_db, Base
 
 
 
 
 
 
 
12
 
13
  # Use a file-based SQLite database for testing to ensure persistence
14
  TEST_DATABASE_URL = "sqlite+aiosqlite:///./test_blink_data.db"
@@ -34,16 +47,6 @@ async def db_session(test_engine):
34
 
35
  async with async_session() as session:
36
  yield session
37
-
38
- # We don't drop tables here to allow persistence if needed,
39
- # but for isolation we usually want to.
40
- # However, the previous test relied on persistence for rate limiting.
41
- # Let's keep it simple: we clear data manually if needed or rely on fresh DB per run (session scope engine).
42
- # Actually, for rate limiting test to work across requests, we need persistence.
43
- # But for isolation between tests, we want cleanup.
44
- # The previous test_app.py had cleanup_db fixture. Let's replicate that logic in the test file or here.
45
-
46
- # Let's just yield session here.
47
 
48
  @pytest.fixture(scope="function")
49
  def client(test_engine):
@@ -57,6 +60,12 @@ def client(test_engine):
57
  yield session
58
 
59
  app.dependency_overrides[get_db] = override_get_db
60
- with TestClient(app) as c:
61
- yield c
 
 
 
 
 
62
  app.dependency_overrides.clear()
 
 
1
  import pytest
2
  import os
3
  import sys
4
+ from unittest.mock import patch, MagicMock
5
  from fastapi.testclient import TestClient
6
  from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
7
 
8
+ # Set test environment variables BEFORE importing app
9
+ os.environ["JWT_SECRET"] = "test-secret-key-that-is-long-enough-for-security-purposes"
10
+ os.environ["GOOGLE_CLIENT_ID"] = "test-google-client-id.apps.googleusercontent.com"
11
+ os.environ["RESET_DB"] = "true" # Prevent Drive download during tests
12
+
13
  # Add parent directory to path to allow importing app
14
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
15
 
16
+ # Mock the drive service before importing app
17
+ with patch("services.drive_service.DriveService") as mock_drive:
18
+ mock_instance = MagicMock()
19
+ mock_instance.download_db.return_value = False
20
+ mock_instance.upload_db.return_value = True
21
+ mock_drive.return_value = mock_instance
22
+
23
+ from app import app
24
+ from core.database import get_db, Base
25
 
26
  # Use a file-based SQLite database for testing to ensure persistence
27
  TEST_DATABASE_URL = "sqlite+aiosqlite:///./test_blink_data.db"
 
47
 
48
  async with async_session() as session:
49
  yield session
 
 
 
 
 
 
 
 
 
 
50
 
51
  @pytest.fixture(scope="function")
52
  def client(test_engine):
 
60
  yield session
61
 
62
  app.dependency_overrides[get_db] = override_get_db
63
+
64
+ # Mock drive service for the test client
65
+ with patch("routers.auth.drive_service") as mock_auth_drive:
66
+ mock_auth_drive.upload_db.return_value = True
67
+ with TestClient(app) as c:
68
+ yield c
69
+
70
  app.dependency_overrides.clear()
71
+
tests/debug_gemini_service.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Debug script to test Gemini service with API keys from environment.
3
+ Keys should be in GEMINI_KEYS environment variable, comma-separated.
4
+
5
+ Usage:
6
+ GEMINI_KEYS="key1,key2,key3" python tests/debug_gemini_service.py
7
+ """
8
+ import os
9
+ import sys
10
+ import asyncio
11
+ import logging
12
+ import base64
13
+ from dotenv import load_dotenv
14
+
15
+ # Load environment variables
16
+ load_dotenv()
17
+
18
+ # Add parent directory to path
19
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
20
+
21
+ from services.gemini_service import GeminiService, MODELS
22
+
23
+ # Configure logging
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # Test image path
28
+ TEST_IMAGE_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test.jpg")
29
+
30
+
31
+ def load_test_image():
32
+ """Load test image and return base64 + mime type."""
33
+ if not os.path.exists(TEST_IMAGE_PATH):
34
+ logger.error(f"Test image not found: {TEST_IMAGE_PATH}")
35
+ return None, None
36
+
37
+ with open(TEST_IMAGE_PATH, "rb") as f:
38
+ image_data = f.read()
39
+
40
+ base64_image = base64.b64encode(image_data).decode("utf-8")
41
+ mime_type = "image/jpeg"
42
+
43
+ logger.info(f"Loaded test image: {TEST_IMAGE_PATH} ({len(image_data)} bytes)")
44
+ return base64_image, mime_type
45
+
46
+
47
+ async def test_generate_text(service: GeminiService, key_index: int):
48
+ """Test simple text generation."""
49
+ logger.info(f"[Key {key_index}] Testing text generation...")
50
+ try:
51
+ result = await service.generate_text("Say hello in one word.")
52
+ logger.info(f"[Key {key_index}] Text generation result: {result[:100]}...")
53
+ return True
54
+ except Exception as e:
55
+ logger.error(f"[Key {key_index}] Text generation failed: {e}")
56
+ return False
57
+
58
+
59
+ async def test_analyze_image(service: GeminiService, key_index: int, base64_image: str, mime_type: str):
60
+ """Test image analysis."""
61
+ logger.info(f"[Key {key_index}] Testing image analysis...")
62
+ try:
63
+ result = await service.analyze_image(
64
+ base64_image=base64_image,
65
+ mime_type=mime_type,
66
+ prompt="Describe this image in one sentence."
67
+ )
68
+ logger.info(f"[Key {key_index}] Image analysis result: {result[:100]}...")
69
+ return True
70
+ except Exception as e:
71
+ logger.error(f"[Key {key_index}] Image analysis failed: {e}")
72
+ return False
73
+
74
+
75
+ async def test_generate_animation_prompt(service: GeminiService, key_index: int, base64_image: str, mime_type: str):
76
+ """Test animation prompt generation."""
77
+ logger.info(f"[Key {key_index}] Testing animation prompt generation...")
78
+ try:
79
+ result = await service.generate_animation_prompt(
80
+ base64_image=base64_image,
81
+ mime_type=mime_type
82
+ )
83
+ logger.info(f"[Key {key_index}] Animation prompt result: {result[:100]}...")
84
+ return True
85
+ except Exception as e:
86
+ logger.error(f"[Key {key_index}] Animation prompt generation failed: {e}")
87
+ return False
88
+
89
+
90
+ async def test_key(api_key: str, key_index: int, base64_image: str, mime_type: str):
91
+ """Test all basic operations with a single API key."""
92
+ logger.info(f"\n{'='*50}")
93
+ logger.info(f"Testing Key {key_index}: {api_key[:10]}...{api_key[-4:]}")
94
+ logger.info(f"{'='*50}")
95
+
96
+ try:
97
+ service = GeminiService(api_key)
98
+ except Exception as e:
99
+ logger.error(f"[Key {key_index}] Failed to initialize service: {e}")
100
+ return {"key_index": key_index, "valid": False, "error": str(e)}
101
+
102
+ results = {
103
+ "key_index": key_index,
104
+ "key_preview": f"{api_key[:10]}...{api_key[-4:]}",
105
+ "text_generation": await test_generate_text(service, key_index),
106
+ "image_analysis": await test_analyze_image(service, key_index, base64_image, mime_type),
107
+ "animation_prompt": await test_generate_animation_prompt(service, key_index, base64_image, mime_type),
108
+ }
109
+
110
+ results["valid"] = all([
111
+ results["text_generation"],
112
+ results["image_analysis"],
113
+ results["animation_prompt"]
114
+ ])
115
+
116
+ return results
117
+
118
+
119
+ async def main():
120
+ # Load test image
121
+ base64_image, mime_type = load_test_image()
122
+ if not base64_image:
123
+ logger.error("Cannot run tests without test image. Please add test.jpg to project root.")
124
+ return
125
+
126
+ gemini_keys_str = os.getenv("GEMINI_KEYS", "")
127
+
128
+ if not gemini_keys_str:
129
+ logger.error("GEMINI_KEYS environment variable not set.")
130
+ logger.info("Usage: GEMINI_KEYS='key1,key2,key3' python tests/debug_gemini_service.py")
131
+ return
132
+
133
+ keys = [k.strip() for k in gemini_keys_str.split(",") if k.strip()]
134
+
135
+ if not keys:
136
+ logger.error("No valid keys found in GEMINI_KEYS.")
137
+ return
138
+
139
+ logger.info(f"Found {len(keys)} API key(s) to test.")
140
+ logger.info(f"Available models: {MODELS}")
141
+
142
+ all_results = []
143
+ for i, key in enumerate(keys):
144
+ result = await test_key(key, i + 1, base64_image, mime_type)
145
+ all_results.append(result)
146
+
147
+ # Summary
148
+ logger.info(f"\n{'='*50}")
149
+ logger.info("SUMMARY")
150
+ logger.info(f"{'='*50}")
151
+
152
+ valid_count = sum(1 for r in all_results if r.get("valid", False))
153
+ logger.info(f"Valid keys: {valid_count}/{len(keys)}")
154
+
155
+ for result in all_results:
156
+ status = "✓ VALID" if result.get("valid") else "✗ INVALID"
157
+ logger.info(f" Key {result['key_index']}: {status}")
158
+ if not result.get("valid"):
159
+ for test_name in ["text_generation", "image_analysis", "animation_prompt"]:
160
+ if test_name in result and not result[test_name]:
161
+ logger.info(f" - {test_name}: FAILED")
162
+
163
+
164
+ if __name__ == "__main__":
165
+ asyncio.run(main())
tests/test_integration.py CHANGED
@@ -1,25 +1,28 @@
 
 
 
 
 
1
  import pytest
2
- from unittest.mock import patch
3
  import os
4
  from sqlalchemy import text
5
 
 
 
 
 
6
  # Cleanup fixture
7
  @pytest.fixture(autouse=True)
8
  def cleanup_db():
9
  if os.path.exists("./test_blink_data.db"):
10
- # We can't easily delete the file if it's open by the engine in conftest.
11
- # Instead, we should probably truncate tables.
12
  pass
13
  yield
14
- # Cleanup logic if needed
15
 
16
- # We need a way to clear data between tests if we want isolation.
17
- # Since we are using a file-based DB shared across the session (engine),
18
- # we should truncate tables.
19
 
20
  @pytest.fixture(autouse=True)
21
  async def clear_tables(db_session):
22
- # Truncate all tables
23
  async with db_session.begin():
24
  await db_session.execute(text("DELETE FROM users"))
25
  await db_session.execute(text("DELETE FROM rate_limits"))
@@ -27,65 +30,179 @@ async def clear_tables(db_session):
27
  await db_session.execute(text("DELETE FROM blink_data"))
28
  await db_session.commit()
29
 
30
- @patch("services.email_service.send_email")
31
- def test_credit_system_flow(mock_send_email, client):
32
- mock_send_email.return_value = True
33
-
34
- # 1. Register
35
- response = client.post("/auth/register", json={
36
- "user_id": "test-user-1",
37
- "email": "test@example.com"
38
- })
39
- assert response.status_code == 200
40
- assert response.json()["success"] == True
 
 
 
 
 
 
 
 
 
 
41
 
42
- # 2. Check registration
43
- response = client.post("/auth/check-registration", json={"user_id": "test-user-1"})
44
- assert response.status_code == 200
45
- assert response.json()["is_registered"] == True
46
-
47
- # 3. Validate with mocked key (we need to know the key)
48
- # Since we can't easily get the key from the hashed DB, let's mock generate_secret_key
49
- with patch("routers.auth.generate_secret_key", return_value="sk_test_key_1234567890123456789012345"):
50
- # Register user 2
51
- client.post("/auth/register", json={
52
- "user_id": "test-user-2",
53
- "email": "[email protected]"
54
  })
55
 
56
- # Validate
57
- response = client.post("/auth/validate", headers={"X-Secret-Key": "sk_test_key_1234567890123456789012345"})
58
  assert response.status_code == 200
59
- assert response.json()["valid"] == True
60
- assert response.json()["credits"] == 100
61
-
62
- def test_blink_flow(client):
63
- # Test Blink Endpoint
64
- # We need a valid userid format: 20 chars + encrypted data
65
- user_id = "12345678901234567890"
66
- encrypted_data = "some_encrypted_data_base64"
67
- userid_param = user_id + encrypted_data
68
 
69
- response = client.get(f"/blink?userid={userid_param}")
70
- assert response.status_code == 200
71
- data = response.json()
72
- assert data["status"] == "success"
73
- assert data["user_id"] == user_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- # Verify data stored (via API)
76
- response = client.get("/api/data")
77
- assert response.status_code == 200
78
- items = response.json()["items"]
79
- assert len(items) > 0
80
- assert items[0]["user_id"] == user_id
81
-
82
- @patch("services.email_service.send_email")
83
- def test_rate_limiting(mock_send_email, client):
84
- # 10 requests should succeed
85
- for _ in range(10):
86
- response = client.post("/auth/check-registration", json={"user_id": "rate-limit-test"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  assert response.status_code == 200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # 11th request should fail
90
- response = client.post("/auth/check-registration", json={"user_id": "rate-limit-test"})
91
- assert response.status_code == 429
 
1
+ """
2
+ Integration Tests for Google OAuth Authentication
3
+
4
+ Tests the new Google Sign-In flow, JWT token handling, and API access.
5
+ """
6
  import pytest
7
+ from unittest.mock import patch, MagicMock
8
  import os
9
  from sqlalchemy import text
10
 
11
+ from services.google_auth_service import GoogleUserInfo
12
+ from services.jwt_service import JWTService
13
+
14
+
15
  # Cleanup fixture
16
  @pytest.fixture(autouse=True)
17
  def cleanup_db():
18
  if os.path.exists("./test_blink_data.db"):
 
 
19
  pass
20
  yield
 
21
 
 
 
 
22
 
23
  @pytest.fixture(autouse=True)
24
  async def clear_tables(db_session):
25
+ """Truncate all tables between tests."""
26
  async with db_session.begin():
27
  await db_session.execute(text("DELETE FROM users"))
28
  await db_session.execute(text("DELETE FROM rate_limits"))
 
30
  await db_session.execute(text("DELETE FROM blink_data"))
31
  await db_session.commit()
32
 
33
+
34
+ @pytest.fixture
35
+ def jwt_service():
36
+ """Create a JWT service for testing."""
37
+ return JWTService(secret_key="test-secret-key-for-testing-only")
38
+
39
+
40
+ @pytest.fixture
41
+ def mock_google_user():
42
+ """Mock Google user info."""
43
+ return GoogleUserInfo(
44
+ google_id="google_123456789",
45
+ email="[email protected]",
46
+ email_verified=True,
47
+ name="Test User",
48
+ picture="https://example.com/photo.jpg"
49
+ )
50
+
51
+
52
+ class TestGoogleAuth:
53
+ """Test Google OAuth authentication flow."""
54
 
55
+ @patch("routers.auth.get_google_auth_service")
56
+ def test_google_auth_new_user(self, mock_get_service, client, mock_google_user):
57
+ """Test new user registration via Google."""
58
+ mock_service = MagicMock()
59
+ mock_service.verify_token.return_value = mock_google_user
60
+ mock_get_service.return_value = mock_service
61
+
62
+ response = client.post("/auth/google", json={
63
+ "id_token": "fake-google-token-12345",
64
+ "temp_user_id": "temp-user-abc"
 
 
65
  })
66
 
 
 
67
  assert response.status_code == 200
68
+ data = response.json()
69
+ assert data["success"] == True
70
+ assert data["is_new_user"] == True
71
+ assert data["email"] == "[email protected]"
72
+ assert data["name"] == "Test User"
73
+ assert data["credits"] == 100
74
+ assert "access_token" in data
75
+ assert data["access_token"] != ""
 
76
 
77
+ @patch("routers.auth.get_google_auth_service")
78
+ def test_google_auth_existing_user(self, mock_get_service, client, mock_google_user):
79
+ """Test existing user login via Google."""
80
+ mock_service = MagicMock()
81
+ mock_service.verify_token.return_value = mock_google_user
82
+ mock_get_service.return_value = mock_service
83
+
84
+ # First login - creates user
85
+ response1 = client.post("/auth/google", json={"id_token": "token1"})
86
+ assert response1.status_code == 200
87
+ assert response1.json()["is_new_user"] == True
88
+
89
+ # Second login - same user
90
+ response2 = client.post("/auth/google", json={"id_token": "token2"})
91
+ assert response2.status_code == 200
92
+ data = response2.json()
93
+ assert data["is_new_user"] == False
94
+ assert data["email"] == "[email protected]"
95
+ assert data["credits"] == 100 # Credits preserved
96
 
97
+ @patch("routers.auth.get_google_auth_service")
98
+ def test_google_auth_invalid_token(self, mock_get_service, client):
99
+ """Test handling of invalid Google token."""
100
+ from services.google_auth_service import InvalidTokenError
101
+
102
+ mock_service = MagicMock()
103
+ mock_service.verify_token.side_effect = InvalidTokenError("Invalid token")
104
+ mock_get_service.return_value = mock_service
105
+
106
+ response = client.post("/auth/google", json={"id_token": "invalid-token"})
107
+
108
+ assert response.status_code == 401
109
+ assert "Invalid Google token" in response.json()["detail"]
110
+
111
+
112
+ class TestJWTAuth:
113
+ """Test JWT token authentication."""
114
+
115
+ @patch("routers.auth.get_google_auth_service")
116
+ def test_get_current_user(self, mock_get_service, client, mock_google_user):
117
+ """Test getting current user with JWT."""
118
+ mock_service = MagicMock()
119
+ mock_service.verify_token.return_value = mock_google_user
120
+ mock_get_service.return_value = mock_service
121
+
122
+ # Login to get token
123
+ login_response = client.post("/auth/google", json={"id_token": "token"})
124
+ token = login_response.json()["access_token"]
125
+
126
+ # Get user info
127
+ response = client.get("/auth/me", headers={"Authorization": f"Bearer {token}"})
128
+
129
+ assert response.status_code == 200
130
+ data = response.json()
131
+ assert data["email"] == "[email protected]"
132
+ assert data["credits"] == 100
133
+
134
+ def test_missing_auth_header(self, client):
135
+ """Test request without Authorization header."""
136
+ response = client.get("/auth/me")
137
+ assert response.status_code == 401
138
+ assert "Missing Authorization header" in response.json()["detail"]
139
+
140
+ def test_invalid_token_format(self, client):
141
+ """Test request with invalid token format."""
142
+ response = client.get("/auth/me", headers={"Authorization": "InvalidFormat"})
143
+ assert response.status_code == 401
144
+ assert "Invalid Authorization header format" in response.json()["detail"]
145
+
146
+ def test_invalid_token(self, client):
147
+ """Test request with invalid JWT token."""
148
+ response = client.get("/auth/me", headers={"Authorization": "Bearer invalid.jwt.token"})
149
+ assert response.status_code == 401
150
+
151
+
152
+ class TestCreditSystem:
153
+ """Test credit deduction system."""
154
+
155
+ @patch("routers.auth.get_google_auth_service")
156
+ def test_credit_deduction(self, mock_get_service, client, mock_google_user):
157
+ """Test that credits are deducted when using API."""
158
+ mock_service = MagicMock()
159
+ mock_service.verify_token.return_value = mock_google_user
160
+ mock_get_service.return_value = mock_service
161
+
162
+ # Login
163
+ login_response = client.post("/auth/google", json={"id_token": "token"})
164
+ token = login_response.json()["access_token"]
165
+ initial_credits = login_response.json()["credits"]
166
+
167
+ # Make an API call that deducts credits (would need gemini endpoint mock)
168
+ # For now, just verify user info doesn't deduct credits
169
+ response = client.get("/auth/me", headers={"Authorization": f"Bearer {token}"})
170
+ assert response.json()["credits"] == initial_credits # No deduction for info endpoint
171
+
172
+
173
+ class TestBlinkFlow:
174
+ """Test blink data collection."""
175
+
176
+ def test_blink_flow(self, client):
177
+ """Test Blink endpoint still works."""
178
+ user_id = "12345678901234567890"
179
+ encrypted_data = "some_encrypted_data_base64"
180
+ userid_param = user_id + encrypted_data
181
+
182
+ response = client.get(f"/blink?userid={userid_param}")
183
+ assert response.status_code == 200
184
+ data = response.json()
185
+ assert data["status"] == "success"
186
+ assert data["user_id"] == user_id
187
+
188
+ # Verify data stored
189
+ response = client.get("/api/data")
190
  assert response.status_code == 200
191
+ items = response.json()["items"]
192
+ assert len(items) > 0
193
+ assert items[0]["user_id"] == user_id
194
+
195
+
196
+ class TestRateLimiting:
197
+ """Test rate limiting."""
198
+
199
+ def test_rate_limiting(self, client):
200
+ """Test rate limiting on auth endpoints."""
201
+ # 10 requests should succeed
202
+ for _ in range(10):
203
+ response = client.post("/auth/check-registration", json={"user_id": "rate-limit-test"})
204
+ assert response.status_code == 200
205
 
206
+ # 11th request should fail
207
+ response = client.post("/auth/check-registration", json={"user_id": "rate-limit-test"})
208
+ assert response.status_code == 429