google sign in

Files changed:
- .gitignore +4 -0
- app.py +12 -1
- core/models.py +35 -1
- core/schemas.py +40 -6
- core/security.py +15 -8
- dependencies.py +85 -30
- docs/CLIENT_INTEGRATION.md +287 -0
- generate_jwt_secret.py +95 -0
- requirements.txt +3 -0
- routers/auth.py +197 -187
- routers/gemini.py +352 -0
- services/drive_service.py +4 -3
- services/email_service.py +4 -3
- services/gemini_service.py +344 -0
- services/google_auth_service.py +232 -0
- services/job_worker.py +278 -0
- services/jwt_service.py +378 -0
- tests/conftest.py +23 -14
- tests/debug_gemini_service.py +165 -0
- tests/test_integration.py +179 -62
.gitignore
CHANGED

@@ -47,3 +47,7 @@ build/
 # Private keys (keep in repo but be careful)
 PRIVATE_KEY.pem
 client_secret.json
+test.jpg
+
+# Downloaded video files
+downloads/
app.py
CHANGED

@@ -12,7 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse

 from core.database import init_db
-from routers import auth, blink, general
+from routers import auth, blink, general, gemini
 from services.drive_service import DriveService

 # Configure logging
@@ -46,8 +46,18 @@ async def lifespan(app: FastAPI):

     await init_db()
     logger.info("Database initialized successfully")
+
+    # Start background job worker
+    from services.job_worker import start_worker, stop_worker
+    await start_worker()
+    logger.info("Background job worker started")
+
     yield

+    # Stop background job worker
+    await stop_worker()
+    logger.info("Background job worker stopped")
+
     # Shutdown: Upload DB to Drive
     logger.info("Shutdown: Uploading database to Google Drive...")
     drive_service.upload_db()
@@ -75,6 +85,7 @@ app.add_middleware(
 app.include_router(general.router)
 app.include_router(auth.router)
 app.include_router(blink.router)
+app.include_router(gemini.router)


 @app.exception_handler(Exception)
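The lifespan hooks above assume `services/job_worker.py` (added in this commit, +278 lines, not shown in this excerpt) exposes `start_worker()` and `stop_worker()`. A minimal sketch of that pattern, assuming a polling loop and a module-level task handle, could look like the following; the actual implementation in the commit may differ, and `process_next_job` is a hypothetical placeholder:

```python
# Hypothetical sketch only - the real services/job_worker.py is not shown in this diff.
import asyncio
import logging
from typing import Optional

logger = logging.getLogger(__name__)

_worker_task: Optional[asyncio.Task] = None
_stop_event = asyncio.Event()


async def process_next_job() -> None:
    # Placeholder: the real worker presumably claims the oldest queued
    # GeminiJob row, calls the Gemini service, and stores the result.
    await asyncio.sleep(0)


async def _worker_loop() -> None:
    """Poll for queued jobs until asked to stop."""
    while not _stop_event.is_set():
        try:
            await process_next_job()
        except Exception:
            logger.exception("Job processing failed")
        await asyncio.sleep(2)  # polling interval is an assumption


async def start_worker() -> None:
    """Start the background task (called from the lifespan startup above)."""
    global _worker_task
    _stop_event.clear()
    _worker_task = asyncio.create_task(_worker_loop())


async def stop_worker() -> None:
    """Signal the loop to stop and wait for it (called on shutdown)."""
    _stop_event.set()
    if _worker_task is not None:
        await _worker_task
```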
core/models.py
CHANGED

@@ -37,6 +37,7 @@ class BlinkData(Base):
 class User(Base):
     """
     User model for credit system.
+    Supports both legacy secret key and Google OAuth authentication.
     """
     __tablename__ = "users"

@@ -44,7 +45,16 @@ class User(Base):
     user_id = Column(String(50), unique=True, index=True, nullable=False)  # Backend generated UUID
     temp_user_id = Column(String(50), index=True, nullable=True)  # From frontend
     email = Column(String(255), unique=True, index=True, nullable=False)
+
+    # Google OAuth fields
+    google_id = Column(String(255), unique=True, index=True, nullable=True)  # Google sub claim
+    name = Column(String(255), nullable=True)  # Display name from Google
+    profile_picture = Column(Text, nullable=True)  # Google profile picture URL
+
+    # Legacy field (kept for migration, nullable now)
+    secret_key_hash = Column(String(255), nullable=True)
+
+    # Credits and status
     credits = Column(Integer, default=100)
     created_at = Column(DateTime(timezone=True), server_default=func.now())
     updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
@@ -55,6 +65,7 @@ class User(Base):
         return f"<User(id={self.id}, email={self.email})>"


+
 class RateLimit(Base):
     """
     Rate limit tracking table.
@@ -83,3 +94,26 @@ class AuditLog(Base):
     status = Column(String(20), nullable=False)
     error_message = Column(Text, nullable=True)
     timestamp = Column(DateTime(timezone=True), server_default=func.now())
+
+
+class GeminiJob(Base):
+    """
+    Generic job queue for Gemini operations (video, image, text).
+    """
+    __tablename__ = "gemini_jobs"
+
+    id = Column(Integer, primary_key=True, autoincrement=True, index=True)
+    job_id = Column(String(100), unique=True, index=True, nullable=False)  # Our ID for client
+    user_id = Column(String(50), index=True, nullable=False)  # User who requested
+    job_type = Column(String(20), index=True, nullable=False)  # video, image, text, analyze
+    third_party_id = Column(String(255), nullable=True)  # Gemini operation name (for video)
+    status = Column(String(20), default="queued", index=True)  # queued, processing, completed, failed
+    input_data = Column(JSON, nullable=True)  # Request details (prompt, settings, etc.)
+    output_data = Column(JSON, nullable=True)  # Result (filename, text, etc.)
+    error_message = Column(Text, nullable=True)
+    created_at = Column(DateTime(timezone=True), server_default=func.now())
+    started_at = Column(DateTime(timezone=True), nullable=True)
+    completed_at = Column(DateTime(timezone=True), nullable=True)
+
+    def __repr__(self):
+        return f"<GeminiJob(job_id={self.job_id}, type={self.job_type}, status={self.status})>"
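For reference, the background worker added elsewhere in this commit would presumably claim work from this table roughly as in the sketch below. Only the `GeminiJob` columns above are taken from the commit; the query shape and the helper name are assumptions:

```python
# Sketch: claim the oldest queued GeminiJob and mark it as processing.
# Assumes an AsyncSession `db` and the GeminiJob model defined above.
from datetime import datetime
from typing import Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from core.models import GeminiJob


async def claim_next_job(db: AsyncSession) -> Optional[GeminiJob]:
    query = (
        select(GeminiJob)
        .where(GeminiJob.status == "queued")
        .order_by(GeminiJob.created_at)
        .limit(1)
    )
    job = (await db.execute(query)).scalar_one_or_none()
    if job is not None:
        job.status = "processing"
        job.started_at = datetime.utcnow()
        await db.commit()
    return job
```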
core/schemas.py
CHANGED

@@ -1,12 +1,46 @@
 from pydantic import BaseModel, EmailStr, Field
+from typing import Optional
+
+# Pydantic Models for Google OAuth Authentication

-# Pydantic Models
 class CheckRegistrationRequest(BaseModel):
+    """Check if a user_id has completed registration."""
     user_id: str = Field(..., min_length=1, description="Temporary user ID from frontend")

-class RegisterRequest(BaseModel):
-    user_id: str = Field(..., min_length=1, description="Temporary user ID from frontend")
-    email: EmailStr = Field(..., description="User email address")

-class
+class GoogleAuthRequest(BaseModel):
+    """Request with Google ID token from frontend Sign-In."""
+    id_token: str = Field(..., min_length=1, description="Google ID token from Sign-In")
+    temp_user_id: Optional[str] = Field(None, description="Optional temp user ID for linking")
+
+
+class AuthResponse(BaseModel):
+    """Response after successful Google authentication."""
+    success: bool
+    access_token: str
+    user_id: str
+    email: str
+    name: Optional[str] = None
+    credits: int
+    is_new_user: bool
+
+
+class UserInfoResponse(BaseModel):
+    """Response containing current user information."""
+    user_id: str
+    email: str
+    name: Optional[str] = None
+    credits: int
+    profile_picture: Optional[str] = None
+
+
+class TokenRefreshRequest(BaseModel):
+    """Request to refresh an access token."""
+    token: str = Field(..., description="Current access token to refresh")
+
+
+class TokenRefreshResponse(BaseModel):
+    """Response with refreshed access token."""
+    success: bool
+    access_token: str
+
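As a quick illustration of the wire format these schemas imply, the following constructs example payloads for the sign-in round trip; field names come from the schemas above, while the values are made up:

```python
# Example payloads matching the schemas above (illustrative values only).
from core.schemas import GoogleAuthRequest, AuthResponse

# What the frontend POSTs to /auth/google:
auth_request = GoogleAuthRequest(
    id_token="eyJhbGciOiJSUzI1NiIs...",   # Google ID token (truncated example)
    temp_user_id="temp_abc123",           # optional linking ID
)

# What the backend returns on success:
auth_response = AuthResponse(
    success=True,
    access_token="eyJhbGciOiJIUzI1NiIs...",  # backend-issued JWT (example)
    user_id="usr_2f6c0d1e",
    email="user@example.com",
    name="Ada Lovelace",
    credits=100,
    is_new_user=True,
)
```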
core/security.py
CHANGED

@@ -1,25 +1,32 @@
+"""
+Core Security Utilities
+
+Note: Secret key authentication has been replaced with Google OAuth.
+The bcrypt functions below are kept for potential future use (e.g., admin passwords).
+"""
 import bcrypt

+
 def verify_password(plain_password: str, hashed_password: str) -> bool:
     """
-    Verify a password against a hash.
+    Verify a password against a bcrypt hash.
+
+    Note: This is no longer used for user authentication (moved to Google OAuth).
+    Kept for potential admin/internal use cases.
     """
     if isinstance(hashed_password, str):
         hashed_password = hashed_password.encode('utf-8')
     return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password)

+
 def get_password_hash(password: str) -> str:
     """
     Hash a password using bcrypt.
+
+    Note: This is no longer used for user authentication (moved to Google OAuth).
+    Kept for potential admin/internal use cases.
     """
-    # rounds=12 as per spec
     salt = bcrypt.gensalt(rounds=12)
     hashed = bcrypt.hashpw(password.encode('utf-8'), salt)
     return hashed.decode('utf-8')

-def generate_secret_key() -> str:
-    """
-    Generate a secure secret key starting with 'sk_'.
-    """
-    return "sk_" + secrets.token_urlsafe(32)
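The retained bcrypt helpers are used in the usual way; a round trip looks like this:

```python
from core.security import get_password_hash, verify_password

hashed = get_password_hash("correct horse battery staple")
assert verify_password("correct horse battery staple", hashed)
assert not verify_password("wrong password", hashed)
```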
dependencies.py
CHANGED

@@ -9,7 +9,12 @@ from sqlalchemy.ext.asyncio import AsyncSession

 from core.database import get_db
 from core.models import User, RateLimit
-from
+from services.jwt_service import (
+    verify_access_token,
+    TokenExpiredError,
+    InvalidTokenError,
+    JWTError
+)

 logger = logging.getLogger(__name__)

@@ -63,57 +68,107 @@ async def check_rate_limit(
     await db.commit()
     return True

+
+async def get_current_user(
     req: Request,
     db: AsyncSession = Depends(get_db)
 ) -> User:
     """
+    Extract and verify JWT from Authorization header.
+    Returns the authenticated user.
+
+    Usage:
+        @router.get("/protected")
+        async def protected_route(user: User = Depends(get_current_user)):
+            return {"user_id": user.user_id}
     """
+    auth_header = req.headers.get("Authorization")
+
+    if not auth_header:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Missing
+            detail="Missing Authorization header",
+            headers={"WWW-Authenticate": "Bearer"}
         )

-    if not secret_key.startswith("sk_"):
+    if not auth_header.startswith("Bearer "):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Invalid
+            detail="Invalid Authorization header format. Use: Bearer <token>",
+            headers={"WWW-Authenticate": "Bearer"}
         )
+
+    token = auth_header.split(" ", 1)[1]
+
+    try:
+        payload = verify_access_token(token)
+    except TokenExpiredError:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Token has expired. Please sign in again.",
+            headers={"WWW-Authenticate": "Bearer"}
+        )
+    except InvalidTokenError as e:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=f"Invalid token: {str(e)}",
+            headers={"WWW-Authenticate": "Bearer"}
+        )
+    except JWTError as e:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=f"Authentication error: {str(e)}",
+            headers={"WWW-Authenticate": "Bearer"}
+        )
+
+    # Get user from DB
+    query = select(User).where(
+        User.user_id == payload.user_id,
+        User.is_active == True
+    )
     result = await db.execute(query)
-    for user in users:
-        if verify_password(secret_key, user.secret_key_hash):
-            valid_user = user
-            break
-
-    if not valid_user:
+    user = result.scalar_one_or_none()
+
+    if not user:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="
+            detail="User not found or inactive"
         )
+
+    return user
+
+
+async def verify_credits(
+    user: User = Depends(get_current_user),
+    db: AsyncSession = Depends(get_db)
+) -> User:
+    """
+    Verify user has credits and deduct one.
+
+    This dependency first authenticates the user via JWT,
+    then checks and deducts credits.
+
+    Usage:
+        @router.post("/api-endpoint")
+        async def api_endpoint(user: User = Depends(verify_credits)):
+            # User is authenticated and has 1 credit deducted
+            return {"credits_remaining": user.credits}
+    """
+    if user.credits <= 0:
         raise HTTPException(
             status_code=status.HTTP_402_PAYMENT_REQUIRED,
-            detail="Insufficient credits"
+            detail="Insufficient credits. Please purchase more credits."
         )
+
     # Deduct credit
+    user.credits -= 1
+    user.last_used_at = datetime.utcnow()
     await db.commit()

+    logger.debug(f"Deducted 1 credit from user {user.user_id}. Remaining: {user.credits}")
+
+    return user
+

 async def get_geolocation(ip_address: str) -> Tuple[Optional[str], Optional[str]]:
     """
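From a client's point of view, these dependencies translate into status codes: 401 for missing, malformed, or expired tokens, 402 when credits run out. A minimal Python caller, assuming the service runs at `http://localhost:8000` and you already hold a JWT from `/auth/google`, might look like:

```python
import requests

API_BASE = "http://localhost:8000"  # assumption: local dev server
TOKEN = "<JWT issued by /auth/google>"


def api_get(path: str) -> dict:
    resp = requests.get(
        f"{API_BASE}{path}",
        headers={"Authorization": f"Bearer {TOKEN}"},
        timeout=30,
    )
    if resp.status_code == 401:
        raise RuntimeError("Token missing or expired - sign in again or call /auth/refresh")
    if resp.status_code == 402:
        raise RuntimeError("Out of credits")
    resp.raise_for_status()
    return resp.json()


print(api_get("/auth/me"))  # -> user_id, email, name, credits, profile_picture
```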
docs/CLIENT_INTEGRATION.md
ADDED

@@ -0,0 +1,287 @@
# Google Sign-In Client Integration Guide

Quick guide for frontend developers to integrate Google Sign-In with the APIGateway.

## Setup

### 1. Get Your Google Client ID

Use the same `GOOGLE_CLIENT_ID` that's configured on the backend.

### 2. Add Google Identity Services Script

```html
<script src="https://accounts.google.com/gsi/client" async defer></script>
```

---

## Option A: One-Tap / Button Sign-In (Recommended)

```html
<!DOCTYPE html>
<html>
<head>
  <script src="https://accounts.google.com/gsi/client" async defer></script>
</head>
<body>
  <!-- Google One Tap prompt -->
  <div id="g_id_onload"
       data-client_id="YOUR_GOOGLE_CLIENT_ID.apps.googleusercontent.com"
       data-callback="handleGoogleSignIn"
       data-auto_prompt="false">
  </div>

  <!-- Sign In Button -->
  <div class="g_id_signin"
       data-type="standard"
       data-size="large"
       data-theme="outline"
       data-text="sign_in_with"
       data-shape="rectangular"
       data-logo_alignment="left">
  </div>

  <script>
    const API_BASE = 'https://your-api-gateway.com'; // Change this

    async function handleGoogleSignIn(response) {
      try {
        // Send Google token to your backend
        const res = await fetch(`${API_BASE}/auth/google`, {
          method: 'POST',
          headers: { 'Content-Type': 'application/json' },
          body: JSON.stringify({
            id_token: response.credential,
            temp_user_id: localStorage.getItem('temp_user_id') // optional
          })
        });

        const data = await res.json();

        if (data.success) {
          // Store the access token
          localStorage.setItem('access_token', data.access_token);
          localStorage.setItem('user', JSON.stringify({
            user_id: data.user_id,
            email: data.email,
            name: data.name,
            credits: data.credits
          }));

          console.log('Signed in!', data.is_new_user ? 'New user' : 'Existing user');
          // Redirect or update UI
          window.location.reload();
        } else {
          alert('Sign in failed: ' + (data.detail || 'Unknown error'));
        }
      } catch (error) {
        console.error('Sign in error:', error);
        alert('Sign in failed. Please try again.');
      }
    }
  </script>
</body>
</html>
```

---

## Option B: Programmatic Sign-In

```javascript
const API_BASE = 'https://your-api-gateway.com';

// Initialize Google Sign-In
function initGoogleSignIn(clientId) {
  google.accounts.id.initialize({
    client_id: clientId,
    callback: handleGoogleSignIn
  });

  // Render button in a container
  google.accounts.id.renderButton(
    document.getElementById('google-signin-btn'),
    { theme: 'outline', size: 'large', text: 'signin_with' }
  );
}

// Handle the response
async function handleGoogleSignIn(response) {
  const res = await fetch(`${API_BASE}/auth/google`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ id_token: response.credential })
  });

  const data = await res.json();
  if (data.success) {
    localStorage.setItem('access_token', data.access_token);
  }
}

// Call on page load
window.onload = () => initGoogleSignIn('YOUR_CLIENT_ID.apps.googleusercontent.com');
```

---

## Making API Calls

After sign-in, use the access token for all API calls:

```javascript
const API_BASE = 'https://your-api-gateway.com';

async function apiCall(endpoint, options = {}) {
  const token = localStorage.getItem('access_token');

  if (!token) {
    throw new Error('Not signed in');
  }

  const response = await fetch(`${API_BASE}${endpoint}`, {
    ...options,
    headers: {
      'Authorization': `Bearer ${token}`,
      'Content-Type': 'application/json',
      ...options.headers
    }
  });

  if (response.status === 401) {
    // Token expired - need to sign in again
    localStorage.removeItem('access_token');
    window.location.href = '/login';
    return;
  }

  return response.json();
}

// Examples
async function getCurrentUser() {
  return apiCall('/auth/me');
}

async function generateText(prompt) {
  return apiCall('/gemini/generate-text', {
    method: 'POST',
    body: JSON.stringify({ prompt })
  });
}

async function checkJobStatus(jobId) {
  return apiCall(`/gemini/job/${jobId}`);
}
```

---

## API Endpoints Reference

| Endpoint | Method | Auth Required | Description |
|----------|--------|---------------|-------------|
| `/auth/google` | POST | ❌ | Sign in with Google token |
| `/auth/me` | GET | ✅ | Get current user info |
| `/auth/refresh` | POST | ❌ | Refresh access token |
| `/gemini/generate-text` | POST | ✅ | Generate text (costs 1 credit) |
| `/gemini/generate-video` | POST | ✅ | Generate video (costs 1 credit) |
| `/gemini/job/{job_id}` | GET | ✅ | Check job status |

---

## Sign Out

```javascript
function signOut() {
  const token = localStorage.getItem('access_token');

  // Optionally call logout endpoint for audit
  fetch(`${API_BASE}/auth/logout`, {
    method: 'POST',
    headers: { 'Authorization': `Bearer ${token}` }
  });

  localStorage.removeItem('access_token');
  localStorage.removeItem('user');

  // Revoke Google session
  google.accounts.id.disableAutoSelect();

  window.location.href = '/';
}
```

---

## Error Handling

| Status | Meaning | Action |
|--------|---------|--------|
| 401 | Token expired/invalid | Redirect to sign-in |
| 402 | Insufficient credits | Show "buy credits" prompt |
| 429 | Rate limited | Wait and retry |

```javascript
async function apiCallWithErrorHandling(endpoint, options) {
  const response = await apiCall(endpoint, options);

  if (response.status === 402) {
    alert('You have run out of credits!');
    return null;
  }

  return response;
}
```

---

## React Example

```jsx
import { useEffect, useState } from 'react';

function App() {
  const [user, setUser] = useState(null);

  useEffect(() => {
    // Check if already signed in
    const token = localStorage.getItem('access_token');
    if (token) {
      fetch('/auth/me', {
        headers: { 'Authorization': `Bearer ${token}` }
      })
      .then(r => r.json())
      .then(setUser)
      .catch(() => localStorage.removeItem('access_token'));
    }

    // Initialize Google Sign-In
    window.handleGoogleSignIn = async (response) => {
      const res = await fetch('/auth/google', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ id_token: response.credential })
      });
      const data = await res.json();
      if (data.success) {
        localStorage.setItem('access_token', data.access_token);
        setUser(data);
      }
    };
  }, []);

  if (!user) {
    return (
      <div>
        <div id="g_id_onload"
             data-client_id="YOUR_CLIENT_ID"
             data-callback="handleGoogleSignIn" />
        <div className="g_id_signin" data-type="standard" />
      </div>
    );
  }

  return <div>Welcome, {user.name}! Credits: {user.credits}</div>;
}
```
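The guide above covers browser clients in JavaScript; the same submit-and-poll pattern from a Python script or test, assuming a local server, an existing JWT, and that `/gemini/job/{job_id}` echoes the job's `status` field, might look like this sketch:

```python
# Sketch of the submit-then-poll flow against the /gemini job queue.
import time

import requests

API_BASE = "http://localhost:8000"   # assumption: local dev server
HEADERS = {"Authorization": "Bearer <JWT from /auth/google>"}


def generate_text(prompt: str, poll_seconds: float = 2.0) -> dict:
    # 1 credit is deducted up front by the verify_credits dependency.
    submit = requests.post(
        f"{API_BASE}/gemini/generate-text",
        json={"prompt": prompt},
        headers=HEADERS,
        timeout=30,
    )
    submit.raise_for_status()
    job_id = submit.json()["job_id"]

    # Poll until the background worker marks the job completed or failed.
    while True:
        job = requests.get(f"{API_BASE}/gemini/job/{job_id}", headers=HEADERS, timeout=30).json()
        if job.get("status") in ("completed", "failed"):
            return job
        time.sleep(poll_seconds)


print(generate_text("Write a one-line haiku about job queues."))
```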
generate_jwt_secret.py
ADDED

@@ -0,0 +1,95 @@
#!/usr/bin/env python3
"""
Generate JWT Secret Key

This script generates a cryptographically secure secret key for JWT signing.
Run this locally and add the generated key to your .env file.

Usage:
    python generate_jwt_secret.py

    # Or with custom length
    python generate_jwt_secret.py --length 128

Output:
    Prints the secret key and instructions for adding it to your environment.
"""

import argparse
import secrets
import sys


def generate_secret(length: int = 64) -> str:
    """
    Generate a cryptographically secure URL-safe secret.

    Args:
        length: Number of bytes for the secret (default: 64).
                The actual string length will be ~1.3x this due to base64 encoding.

    Returns:
        str: URL-safe base64 encoded secret.
    """
    return secrets.token_urlsafe(length)


def main():
    parser = argparse.ArgumentParser(
        description="Generate a secure JWT secret key",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python generate_jwt_secret.py
    python generate_jwt_secret.py --length 128
    python generate_jwt_secret.py --format docker
"""
    )
    parser.add_argument(
        "--length", "-l",
        type=int,
        default=64,
        help="Number of bytes for the secret (default: 64)"
    )
    parser.add_argument(
        "--format", "-f",
        choices=["env", "docker", "export", "raw"],
        default="env",
        help="Output format (default: env)"
    )

    args = parser.parse_args()

    if args.length < 32:
        print("Warning: Secret length should be at least 32 bytes for security.", file=sys.stderr)

    secret = generate_secret(args.length)

    print("\n" + "=" * 60)
    print("🔐 Generated JWT Secret Key")
    print("=" * 60)

    if args.format == "raw":
        print(secret)
    elif args.format == "env":
        print(f"\nAdd this line to your .env file:\n")
        print(f"JWT_SECRET={secret}")
    elif args.format == "docker":
        print(f"\nAdd this to your docker-compose.yml environment:\n")
        print(f"  - JWT_SECRET={secret}")
    elif args.format == "export":
        print(f"\nRun this command to set the environment variable:\n")
        print(f"export JWT_SECRET='{secret}'")

    print("\n" + "-" * 60)
    print("⚠️  IMPORTANT SECURITY NOTES:")
    print("-" * 60)
    print("• Keep this secret confidential - never commit it to git")
    print("• Use different secrets for development and production")
    print("• If compromised, all existing tokens become invalid")
    print("• Store securely (e.g., secrets manager, encrypted env)")
    print("=" * 60 + "\n")


if __name__ == "__main__":
    main()
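The generated `JWT_SECRET` is presumably consumed by `services/jwt_service.py` (+378 lines, not shown in this excerpt). With `PyJWT>=2.8.0` from the updated requirements, the core of such a service reduces to roughly the sketch below; the algorithm and expiry constants are assumptions, not taken from the commit:

```python
# Hypothetical sketch of the JWT signing/verification core;
# the real services/jwt_service.py is not shown in this excerpt.
import os
from datetime import datetime, timedelta, timezone

import jwt  # PyJWT

JWT_SECRET = os.environ["JWT_SECRET"]   # produced by generate_jwt_secret.py
JWT_ALGORITHM = "HS256"                 # assumption
JWT_EXPIRY_HOURS = 24                   # assumption


def create_access_token(user_id: str, email: str) -> str:
    payload = {
        "user_id": user_id,
        "email": email,
        "exp": datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRY_HOURS),
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)


def verify_access_token(token: str) -> dict:
    # Raises jwt.ExpiredSignatureError / jwt.InvalidTokenError on failure;
    # the committed service appears to wrap these in its own
    # TokenExpiredError / InvalidTokenError types used in dependencies.py.
    return jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
```

Note that `dependencies.py` accesses `payload.user_id` as an attribute, so the committed service most likely wraps the decoded dict in a small payload object rather than returning it directly.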
requirements.txt
CHANGED

@@ -13,3 +13,6 @@ python-dotenv>=1.0.0
 google-api-python-client>=2.0.0
 google-auth-oauthlib>=1.0.0
 google-auth-httplib2>=0.1.0
+google-genai>=1.0.0
+PyJWT>=2.8.0
+
routers/auth.py
CHANGED

@@ -1,21 +1,48 @@
+"""
+Authentication Router - Google OAuth
+
+Endpoints for Google Sign-In authentication flow.
+No more secret keys - users authenticate with their Google account.
+"""
+from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
 from fastapi.responses import JSONResponse
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy import select
 from datetime import datetime
 import uuid
+import logging

 from core.database import get_db
 from core.models import User, AuditLog
-from core.schemas import
+from core.schemas import (
+    CheckRegistrationRequest,
+    GoogleAuthRequest,
+    AuthResponse,
+    UserInfoResponse,
+    TokenRefreshRequest,
+    TokenRefreshResponse
+)
+from services.google_auth_service import (
+    GoogleAuthService,
+    InvalidTokenError as GoogleInvalidTokenError,
+    ConfigurationError as GoogleConfigError,
+    get_google_auth_service
+)
+from services.jwt_service import (
+    JWTService,
+    create_access_token,
+    get_jwt_service,
+    InvalidTokenError as JWTInvalidTokenError
+)
+from dependencies import check_rate_limit, get_current_user
 from services.drive_service import DriveService

+logger = logging.getLogger(__name__)
+
 router = APIRouter(prefix="/auth", tags=["auth"])
 drive_service = DriveService()

+
 @router.post("/check-registration")
 async def check_registration(
     request: CheckRegistrationRequest,
@@ -24,6 +51,7 @@ async def check_registration(
 ):
     """
     Check if a temporary user_id has completed registration.
+    Useful for frontend to check if user needs to sign in.
     """
     # Rate Limit: 10 requests per minute per IP
     ip = req.client.host
@@ -37,227 +65,209 @@ async def check_registration(
     return {"is_registered": user is not None}


-@router.post("/
-async def
-    request:
+@router.post("/google", response_model=AuthResponse)
+async def google_auth(
+    request: GoogleAuthRequest,
     req: Request,
     background_tasks: BackgroundTasks,
     db: AsyncSession = Depends(get_db)
 ):
     """
+    Authenticate with Google ID token.
+
+    Frontend flow:
+    1. User clicks "Sign in with Google" button
+    2. Google returns an ID token
+    3. Frontend sends that token to this endpoint
+    4. We verify it with Google and issue our own JWT
+
+    Creates new user or returns existing user.
+    Existing users matched by email.
     """
-    # Rate Limit: 5 registrations per hour per IP
     ip = req.client.host
+
+    # Rate Limit: 10 attempts per minute per IP
+    if not await check_rate_limit(db, ip, "/auth/google", 10, 1):
+        raise HTTPException(
+            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+            detail="Too many authentication attempts"
+        )
+
+    # Verify Google token
+    try:
+        google_service = get_google_auth_service()
+        google_info = google_service.verify_token(request.id_token)
+    except GoogleConfigError as e:
+        logger.error(f"Google Auth not configured: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Google authentication is not configured"
+        )
+    except GoogleInvalidTokenError as e:
+        logger.warning(f"Invalid Google token from {ip}: {e}")
+
+        # Log failed attempt
+        audit_log = AuditLog(
+            user_id=None,
+            action="google_auth",
+            ip_address=ip,
+            status="failed",
+            error_message=str(e)
+        )
+        db.add(audit_log)
+        await db.commit()
+
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid Google token. Please try signing in again."
+        )
+
+    # Check for existing user by email (preserves credits for migrated users)
+    query = select(User).where(User.email == google_info.email)
     result = await db.execute(query)
-
-    # Generate Secret Key
-    secret_key = generate_secret_key()
-    secret_key_hash = get_password_hash(secret_key)
-    backend_user_id = "usr_" + str(uuid.uuid4())
-
-    # Create User
-    new_user = User(
-        user_id=backend_user_id,
-        temp_user_id=request.user_id,
-        email=request.email,
-        secret_key_hash=secret_key_hash,
-        credits=100
-    )
-    db.add(new_user)
+    user = result.scalar_one_or_none()
+
+    is_new_user = False

+    if user:
+        # Existing user - update Google info
+        if not user.google_id:
+            user.google_id = google_info.google_id
+            logger.info(f"Linked Google account to existing user: {user.email}")
+
+        user.name = google_info.name
+        user.profile_picture = google_info.picture
+        user.last_used_at = datetime.utcnow()
+
+        # Link temp_user_id if provided and not already set
+        if request.temp_user_id and not user.temp_user_id:
+            user.temp_user_id = request.temp_user_id
+    else:
+        # New user - create account
+        is_new_user = True
+        user = User(
+            user_id="usr_" + str(uuid.uuid4()),
+            temp_user_id=request.temp_user_id,
+            email=google_info.email,
+            google_id=google_info.google_id,
+            name=google_info.name,
+            profile_picture=google_info.picture,
+            credits=100
+        )
+        db.add(user)
+        logger.info(f"New user created via Google: {google_info.email}")
+
+    # Log successful auth
     audit_log = AuditLog(
-        user_id=
-        action="
+        user_id=user.user_id,
+        action="google_auth",
         ip_address=ip,
         status="success"
     )
     db.add(audit_log)
     await db.commit()
-
-    #
-    Your secret key is: {secret_key}
-
-    Please save this key securely.
-    You'll need it to access your credits.
-
-    If you lose this key, use the 'Forgot Key'
-    option with this email address.
-
-    Credits: 100
-    Valid from: {datetime.now().strftime('%Y-%m-%d')}
-
-    Do not share this key with anyone."""
-
-    background_tasks.add_task(
-        send_email,
-        request.email,
-        "Your Secret Key - Credit System",
-        email_body
-    )
+
+    # Create our JWT access token
+    access_token = create_access_token(user.user_id, user.email)

     # Sync DB to Drive (Async)
     background_tasks.add_task(drive_service.upload_db)
+
+    return AuthResponse(
+        success=True,
+        access_token=access_token,
+        user_id=user.user_id,
+        email=user.email,
+        name=user.name,
+        credits=user.credits,
+        is_new_user=is_new_user
+    )

-    return {"success": True, "message": "Registration successful. Check your email."}


+@router.get("/me", response_model=UserInfoResponse)
+async def get_current_user_info(
+    user: User = Depends(get_current_user)
+):
+    """
+    Get current authenticated user info.
+
+    Requires Authorization: Bearer <token> header.
+    """
+    return UserInfoResponse(
+        user_id=user.user_id,
+        email=user.email,
+        name=user.name,
+        credits=user.credits,
+        profile_picture=user.profile_picture
+    )


+@router.post("/refresh", response_model=TokenRefreshResponse)
+async def refresh_token(
+    request: TokenRefreshRequest,
     req: Request,
-    background_tasks: BackgroundTasks,
-    x_secret_key: str = Header(..., alias="X-Secret-Key"),
     db: AsyncSession = Depends(get_db)
 ):
     """
+    Refresh an access token.
+
+    Use this when the current token is about to expire
+    (or has recently expired) to get a new one without
+    requiring the user to sign in again.
     """
-    # Rate Limit: 20 validations per hour per IP
     ip = req.client.host
-    if not await check_rate_limit(db, ip, "/auth/validate", 20, 60):
-        raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many validation attempts")
-
-    query = select(User).where(User.is_active == True)
-    result = await db.execute(query)
-    users = result.scalars().all()
-
-    valid_user = None
-    for user in users:
-        if verify_password(x_secret_key, user.secret_key_hash):
-            valid_user = user
-            break
-
-    audit_log = AuditLog(
-        user_id=valid_user.user_id,
-        action="validate",
-        ip_address=ip,
-        status="success"
-    )
-
-    background_tasks.add_task(drive_service.upload_db)
-
-    return
-        "credits": valid_user.credits,
-        "message": "Valid key"
-    }
-    else:
-        # Log Audit
-        audit_log = AuditLog(
-            user_id=None,
-            action="validate",
-            ip_address=ip,
-            status="failed",
-            error_message="Invalid key"
-        )
-        return JSONResponse(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-        )
+    # Rate Limit: 5 refreshes per minute per IP
+    if not await check_rate_limit(db, ip, "/auth/refresh", 5, 1):
+        raise HTTPException(
+            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+            detail="Too many refresh attempts"
+        )
+
+    try:
+        jwt_service = get_jwt_service()
+        new_token = jwt_service.refresh_token(request.token)
+
+        return TokenRefreshResponse(
+            success=True,
+            access_token=new_token
+        )
+    except JWTInvalidTokenError as e:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=f"Cannot refresh token: {str(e)}"
+        )


-@router.post("/
-async def
-    request: ResetRequest,
+@router.post("/logout")
+async def logout(
     req: Request,
     background_tasks: BackgroundTasks,
+    user: User = Depends(get_current_user),
     db: AsyncSession = Depends(get_db)
 ):
     """
-        raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many reset attempts")
-
-    query = select(User).where(User.email == request.email)
-    result = await db.execute(query)
-    user = result.scalar_one_or_none()
-
+    Logout current user.
+
+    Note: JWT tokens are stateless, so this endpoint mainly
+    serves to log the action. Frontend should discard the token.
+
+    For full session invalidation, consider implementing
+    a token blacklist or reducing token expiry times.
+    """
     ip = req.client.host
-    )
-    db.add(audit_log)
-    await db.commit()
-
-    # Send Email
-    email_body = f"""Hello,
-
-    You requested a secret key reset.
-
-    Your NEW secret key is: {new_secret_key}
-
-    Your old secret key is now invalid.
-
-    If you didn't request this, please
-    contact support immediately.
-
-    Current Credits: {user.credits}"""
-
-    background_tasks.add_task(
-        send_email,
-        request.email,
-        "Your New Secret Key",
-        email_body
-    )
-
-    # Sync DB to Drive (Async)
-    background_tasks.add_task(drive_service.upload_db)
-
-    else:
-        # Log Audit (failed/not found)
-        audit_log = AuditLog(
-            user_id=None,
-            action="reset",
-            ip_address=ip,
-            status="failed",
-            error_message="Email not found"
-        )
-    db.add(audit_log)
-    await db.commit()
-
-    # Always return success
-    return {"success": True, "message": "If this email is registered, reset instructions have been sent."}
+
+    # Log logout
+    audit_log = AuditLog(
+        user_id=user.user_id,
+        action="logout",
+        ip_address=ip,
+        status="success"
+    )
+    db.add(audit_log)
+    await db.commit()
+
+    # Sync DB to Drive (Async)
+    background_tasks.add_task(drive_service.upload_db)
+
+    return {"success": True, "message": "Logged out successfully"}
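`services/google_auth_service.py` (+232 lines) is imported above but not included in this excerpt. Verifying a Google ID token is normally done with the `google-auth` package already pulled in by the requirements; a sketch of what `verify_token` likely boils down to follows. Only the library calls and the `google_id`/`email`/`name`/`picture` fields used by the router are grounded in the commit; the class and function names here are assumptions, and the committed service presumably wraps failures in its own `InvalidTokenError`/`ConfigurationError`:

```python
# Hypothetical sketch of the Google ID token verification behind /auth/google.
import os
from dataclasses import dataclass
from typing import Optional

from google.oauth2 import id_token
from google.auth.transport import requests as google_requests


@dataclass
class GoogleUserInfo:
    google_id: str
    email: str
    name: Optional[str]
    picture: Optional[str]


def verify_google_id_token(token: str) -> GoogleUserInfo:
    client_id = os.environ["GOOGLE_CLIENT_ID"]
    # Raises ValueError if the token is invalid, expired, or issued for another audience.
    claims = id_token.verify_oauth2_token(token, google_requests.Request(), client_id)
    return GoogleUserInfo(
        google_id=claims["sub"],
        email=claims["email"],
        name=claims.get("name"),
        picture=claims.get("picture"),
    )
```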
routers/gemini.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Gemini Router - API endpoints for Gemini AI services.
Uses job queue pattern for async processing.
Authentication via JWT (Authorization: Bearer <token>).
"""
import os
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from typing import Optional, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func

from core.database import get_db
from core.models import User, GeminiJob
from services.gemini_service import MODELS, DOWNLOADS_DIR
from dependencies import verify_credits, get_current_user
from datetime import datetime

router = APIRouter(prefix="/gemini", tags=["gemini"])


# Request/Response Models
class GenerateAnimationPromptRequest(BaseModel):
    base64_image: str = Field(..., description="Base64 encoded image data")
    mime_type: str = Field(..., description="MIME type of the image (e.g., image/png)")
    custom_prompt: Optional[str] = Field(None, description="Optional custom prompt for analysis")


class EditImageRequest(BaseModel):
    base64_image: str = Field(..., description="Base64 encoded image data")
    mime_type: str = Field(..., description="MIME type of the image")
    prompt: str = Field(..., description="Edit instructions")


class GenerateVideoRequest(BaseModel):
    base64_image: str = Field(..., description="Base64 encoded image data")
    mime_type: str = Field(..., description="MIME type of the image")
    prompt: str = Field(..., description="Video generation prompt")
    aspect_ratio: Literal["16:9", "9:16"] = Field("16:9", description="Video aspect ratio")
    resolution: Literal["720p", "1080p"] = Field("720p", description="Video resolution")
    number_of_videos: int = Field(1, ge=1, le=4, description="Number of videos to generate")


class GenerateTextRequest(BaseModel):
    prompt: str = Field(..., description="Text prompt")
    model: Optional[str] = Field(None, description="Model to use (defaults to gemini-2.5-flash)")


class AnalyzeImageRequest(BaseModel):
    base64_image: str = Field(..., description="Base64 encoded image data")
    mime_type: str = Field(..., description="MIME type of the image")
    prompt: str = Field(..., description="Analysis prompt")


# Authentication is handled by dependencies.verify_credits
# which uses JWT tokens from Authorization: Bearer <token> header


async def get_queue_position(db: AsyncSession, job_id: str) -> int:
    """Get the position of a job in the queue."""
    query = select(func.count()).where(
        GeminiJob.status == "queued",
        GeminiJob.created_at < select(GeminiJob.created_at).where(GeminiJob.job_id == job_id).scalar_subquery()
    )
    result = await db.execute(query)
    return result.scalar() + 1


async def create_job(
    db: AsyncSession,
    user: User,
    job_type: str,
    input_data: dict
) -> GeminiJob:
    """Create a new job in the queue."""
    job_id = f"job_{uuid.uuid4().hex[:16]}"

    job = GeminiJob(
        job_id=job_id,
        user_id=user.user_id,
        job_type=job_type,
        status="queued",
        input_data=input_data
    )
    db.add(job)
    await db.commit()
    await db.refresh(job)

    return job


@router.post("/generate-animation-prompt")
async def generate_animation_prompt(
    request: GenerateAnimationPromptRequest,
    user: User = Depends(verify_credits),
    db: AsyncSession = Depends(get_db)
):
    """
    Queue an animation prompt generation job.
    """
    job = await create_job(
        db=db,
        user=user,
        job_type="animation_prompt",
        input_data={
            "base64_image": request.base64_image,
            "mime_type": request.mime_type,
            "custom_prompt": request.custom_prompt
        }
    )

    position = await get_queue_position(db, job.job_id)

    return {
        "success": True,
        "job_id": job.job_id,
        "status": "queued",
        "position": position,
        "credits_remaining": user.credits
    }


@router.post("/edit-image")
async def edit_image(
    request: EditImageRequest,
    user: User = Depends(verify_credits),
    db: AsyncSession = Depends(get_db)
):
    """
    Queue an image edit job.
    """
    job = await create_job(
        db=db,
        user=user,
        job_type="image",
        input_data={
            "base64_image": request.base64_image,
            "mime_type": request.mime_type,
            "prompt": request.prompt
        }
    )

    position = await get_queue_position(db, job.job_id)

    return {
        "success": True,
        "job_id": job.job_id,
        "status": "queued",
        "position": position,
        "credits_remaining": user.credits
    }


@router.post("/generate-video")
async def generate_video(
    request: GenerateVideoRequest,
    user: User = Depends(verify_credits),
    db: AsyncSession = Depends(get_db)
):
    """
    Queue a video generation job.
    """
    job = await create_job(
        db=db,
        user=user,
        job_type="video",
        input_data={
            "base64_image": request.base64_image,
            "mime_type": request.mime_type,
            "prompt": request.prompt,
            "aspect_ratio": request.aspect_ratio,
            "resolution": request.resolution,
            "number_of_videos": request.number_of_videos
        }
    )

    position = await get_queue_position(db, job.job_id)

    return {
        "success": True,
        "job_id": job.job_id,
        "status": "queued",
        "position": position,
        "credits_remaining": user.credits
    }


@router.post("/generate-text")
async def generate_text(
    request: GenerateTextRequest,
    user: User = Depends(verify_credits),
    db: AsyncSession = Depends(get_db)
):
    """
    Queue a text generation job.
    """
    job = await create_job(
        db=db,
        user=user,
        job_type="text",
        input_data={
            "prompt": request.prompt,
            "model": request.model
        }
    )

    position = await get_queue_position(db, job.job_id)

    return {
        "success": True,
        "job_id": job.job_id,
        "status": "queued",
        "position": position,
        "credits_remaining": user.credits
    }


@router.post("/analyze-image")
async def analyze_image(
    request: AnalyzeImageRequest,
    user: User = Depends(verify_credits),
    db: AsyncSession = Depends(get_db)
):
    """
    Queue an image analysis job.
    """
    job = await create_job(
        db=db,
        user=user,
        job_type="analyze",
        input_data={
            "base64_image": request.base64_image,
            "mime_type": request.mime_type,
            "prompt": request.prompt
        }
    )

    position = await get_queue_position(db, job.job_id)

    return {
        "success": True,
        "job_id": job.job_id,
        "status": "queued",
        "position": position,
        "credits_remaining": user.credits
    }


@router.get("/job/{job_id}")
async def get_job_status(
    job_id: str,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """
    Get the status of a job.
    Poll this endpoint until status is 'completed' or 'failed'.
    """
    query = select(GeminiJob).where(
        GeminiJob.job_id == job_id,
        GeminiJob.user_id == user.user_id
    )
    result = await db.execute(query)
    job = result.scalar_one_or_none()

    if not job:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Job not found"
        )

    response = {
        "success": True,
        "job_id": job.job_id,
        "job_type": job.job_type,
        "status": job.status,
        "created_at": job.created_at.isoformat() if job.created_at else None,
        "credits_remaining": user.credits
    }

    if job.status == "queued":
        response["position"] = await get_queue_position(db, job.job_id)

    if job.status == "processing":
        response["started_at"] = job.started_at.isoformat() if job.started_at else None

    if job.status == "completed":
        response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None
        response["output"] = job.output_data

        # For video jobs, add download URL
        if job.job_type == "video" and job.output_data and job.output_data.get("filename"):
            response["download_url"] = f"/gemini/download/{job.job_id}"

    if job.status == "failed":
        response["error"] = job.error_message
        response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None

    return response


@router.get("/download/{job_id}")
async def download_video(
    job_id: str,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """
    Download a generated video.
    """
    query = select(GeminiJob).where(
        GeminiJob.job_id == job_id,
        GeminiJob.user_id == user.user_id,
        GeminiJob.job_type == "video"
    )
    result = await db.execute(query)
    job = result.scalar_one_or_none()

    if not job:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Job not found"
        )

    if job.status != "completed" or not job.output_data or not job.output_data.get("filename"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Video not ready for download"
        )

    filepath = os.path.join(DOWNLOADS_DIR, job.output_data["filename"])
    if not os.path.exists(filepath):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Video file not found"
        )

    return FileResponse(
        path=filepath,
        media_type="video/mp4",
        filename=f"video_{job_id}.mp4"
    )


@router.get("/models")
async def get_models():
    """
    Get available model names.
    """
    return {"models": MODELS}
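The queue-and-poll flow above is easiest to see from the client side. The following is a minimal sketch, not part of this commit: it assumes the API is reachable at http://localhost:8000 and that a JWT was already obtained from the auth endpoints; the base URL, token placeholder, and prompt text are assumptions.

# Minimal client sketch for the job-queue endpoints above (illustrative only).
# Assumptions: the API runs at http://localhost:8000 and JWT_TOKEN holds a valid token.
import time

import httpx

BASE_URL = "http://localhost:8000"
JWT_TOKEN = "<paste-jwt-here>"
HEADERS = {"Authorization": f"Bearer {JWT_TOKEN}"}


def queue_text_job(prompt: str) -> str:
    """Submit a text-generation job and return its job_id."""
    resp = httpx.post(
        f"{BASE_URL}/gemini/generate-text",
        json={"prompt": prompt},
        headers=HEADERS,
        timeout=30.0,
    )
    resp.raise_for_status()
    return resp.json()["job_id"]


def wait_for_job(job_id: str, poll_seconds: float = 5.0) -> dict:
    """Poll /gemini/job/{job_id} until the worker marks it completed or failed."""
    while True:
        resp = httpx.get(f"{BASE_URL}/gemini/job/{job_id}", headers=HEADERS, timeout=30.0)
        resp.raise_for_status()
        data = resp.json()
        if data["status"] in ("completed", "failed"):
            return data
        time.sleep(poll_seconds)


if __name__ == "__main__":
    job_id = queue_text_job("Write a one-line caption for a sunset photo.")
    result = wait_for_job(job_id)
    print(result.get("output") or result.get("error"))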

services/drive_service.py
CHANGED
@@ -24,9 +24,10 @@ class DriveService:
    def __init__(self):
        self.creds = None
        self.service = None
-
-        self.
-        self.
+        # Server-side credentials for Drive API
+        self.client_id = os.getenv('SERVER_GOOGLE_CLIENT_ID') or os.getenv('GOOGLE_CLIENT_ID')
+        self.client_secret = os.getenv('SERVER_GOOGLE_CLIENT_SECRET') or os.getenv('GOOGLE_CLIENT_SECRET')
+        self.refresh_token = os.getenv('SERVER_GOOGLE_REFRESH_TOKEN') or os.getenv('GOOGLE_REFRESH_TOKEN')

    def authenticate(self):
        """Authenticate using the refresh token."""

services/email_service.py
CHANGED
@@ -32,9 +32,10 @@ class GmailService:
    def __init__(self):
        self.creds = None
        self.service = None
-
-        self.
-        self.
+        # Server-side credentials for Gmail API
+        self.client_id = os.getenv('SERVER_GOOGLE_CLIENT_ID') or os.getenv('GOOGLE_CLIENT_ID')
+        self.client_secret = os.getenv('SERVER_GOOGLE_CLIENT_SECRET') or os.getenv('GOOGLE_CLIENT_SECRET')
+        self.refresh_token = os.getenv('SERVER_GOOGLE_REFRESH_TOKEN') or os.getenv('GOOGLE_REFRESH_TOKEN')
        self.sender_email = os.getenv('SMTP_SENDER') or os.getenv('EMAIL_ID')

    def authenticate(self):

services/gemini_service.py
ADDED
@@ -0,0 +1,344 @@
"""
Gemini AI Service for image and video generation.
Python port of the TypeScript geminiService.ts
Uses server-side API key from environment.
"""
import asyncio
import logging
import os
import uuid
import httpx
from typing import Optional, Literal
from google import genai
from google.genai import types

logger = logging.getLogger(__name__)

# Model names - easily configurable
MODELS = {
    "text_generation": "gemini-2.5-flash",
    "image_edit": "gemini-2.5-flash-image",
    "video_generation": "veo-3.1-fast-generate-preview"
}

# Type aliases
AspectRatio = Literal["16:9", "9:16"]
Resolution = Literal["720p", "1080p"]

# Video downloads directory
DOWNLOADS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "downloads")

# Ensure downloads directory exists
os.makedirs(DOWNLOADS_DIR, exist_ok=True)

# Concurrency limits from environment (defaults)
MAX_CONCURRENT_VIDEOS = int(os.getenv("MAX_CONCURRENT_VIDEOS", "2"))
MAX_CONCURRENT_IMAGES = int(os.getenv("MAX_CONCURRENT_IMAGES", "5"))
MAX_CONCURRENT_TEXT = int(os.getenv("MAX_CONCURRENT_TEXT", "10"))

# Semaphores for concurrency control
_video_semaphore: Optional[asyncio.Semaphore] = None
_image_semaphore: Optional[asyncio.Semaphore] = None
_text_semaphore: Optional[asyncio.Semaphore] = None


def get_video_semaphore() -> asyncio.Semaphore:
    """Get or create video semaphore."""
    global _video_semaphore
    if _video_semaphore is None:
        _video_semaphore = asyncio.Semaphore(MAX_CONCURRENT_VIDEOS)
        logger.info(f"Video semaphore initialized with limit: {MAX_CONCURRENT_VIDEOS}")
    return _video_semaphore


def get_image_semaphore() -> asyncio.Semaphore:
    """Get or create image semaphore."""
    global _image_semaphore
    if _image_semaphore is None:
        _image_semaphore = asyncio.Semaphore(MAX_CONCURRENT_IMAGES)
        logger.info(f"Image semaphore initialized with limit: {MAX_CONCURRENT_IMAGES}")
    return _image_semaphore


def get_text_semaphore() -> asyncio.Semaphore:
    """Get or create text semaphore."""
    global _text_semaphore
    if _text_semaphore is None:
        _text_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TEXT)
        logger.info(f"Text semaphore initialized with limit: {MAX_CONCURRENT_TEXT}")
    return _text_semaphore


def get_gemini_api_key() -> str:
    """Get Gemini API key from environment."""
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY environment variable not set")
    return api_key


class GeminiService:
    """
    Gemini AI Service for text, image, and video generation.
    Uses server-side API key from environment.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Initialize the Gemini client with API key from env or provided."""
        self.api_key = api_key or get_gemini_api_key()
        self.client = genai.Client(api_key=self.api_key)

    def _handle_api_error(self, error: Exception, context: str):
        """Handle API errors with descriptive messages."""
        msg = str(error)
        if "404" in msg or "NOT_FOUND" in msg or "Requested entity was not found" in msg or "[5," in msg:
            raise ValueError(
                f"Model not found ({context}). Ensure your API key project has access to this model. "
                "Veo requires a paid account."
            )
        raise error

    async def generate_animation_prompt(
        self,
        base64_image: str,
        mime_type: str,
        custom_prompt: Optional[str] = None
    ) -> str:
        """
        Analyzes the image to generate a suitable animation prompt.
        """
        default_prompt = custom_prompt or "Describe how this image could be subtly animated with cinematic movement."
        async with get_text_semaphore():
            try:
                response = await asyncio.to_thread(
                    self.client.models.generate_content,
                    model=MODELS["text_generation"],
                    contents=types.Content(
                        parts=[
                            types.Part.from_bytes(
                                data=base64_image,
                                mime_type=mime_type
                            ),
                            types.Part.from_text(text=default_prompt)
                        ]
                    )
                )
                return response.text or "Cinematic subtle movement"
            except Exception as error:
                self._handle_api_error(error, MODELS["text_generation"])

    async def edit_image(
        self,
        base64_image: str,
        mime_type: str,
        prompt: str
    ) -> str:
        """
        Edit an image using Gemini image model.
        Returns base64 data URI of the edited image.
        """
        async with get_image_semaphore():
            try:
                response = await asyncio.to_thread(
                    self.client.models.generate_content,
                    model=MODELS["image_edit"],
                    contents=types.Content(
                        parts=[
                            types.Part.from_bytes(
                                data=base64_image,
                                mime_type=mime_type
                            ),
                            types.Part.from_text(text=prompt or "Enhance this image")
                        ]
                    )
                )

                candidates = response.candidates
                if not candidates:
                    raise ValueError("No candidates returned from Gemini.")

                for part in candidates[0].content.parts:
                    if hasattr(part, 'inline_data') and part.inline_data and part.inline_data.data:
                        result_mime = part.inline_data.mime_type or 'image/png'
                        return f"data:{result_mime};base64,{part.inline_data.data}"

                raise ValueError("No image data found in the response.")
            except Exception as error:
                self._handle_api_error(error, MODELS["image_edit"])

    async def start_video_generation(
        self,
        base64_image: str,
        mime_type: str,
        prompt: str,
        aspect_ratio: AspectRatio = "16:9",
        resolution: Resolution = "720p",
        number_of_videos: int = 1
    ) -> dict:
        """
        Start video generation using Veo model.
        Returns operation details for polling.
        """
        async with get_video_semaphore():
            try:
                # Start video generation
                operation = await asyncio.to_thread(
                    self.client.models.generate_videos,
                    model=MODELS["video_generation"],
                    prompt=prompt,
                    image=types.Image(
                        image_bytes=base64_image,
                        mime_type=mime_type
                    ),
                    config=types.GenerateVideosConfig(
                        number_of_videos=number_of_videos,
                        resolution=resolution,
                        aspect_ratio=aspect_ratio
                    )
                )

                # Return operation details
                return {
                    "gemini_operation_name": operation.name,
                    "done": operation.done,
                    "status": "completed" if operation.done else "pending"
                }
            except Exception as error:
                self._handle_api_error(error, MODELS["video_generation"])

    async def check_video_status(self, gemini_operation_name: str) -> dict:
        """
        Check the status of a video generation operation.
        Returns status and video URL if complete.
        """
        try:
            # Get operation status
            operation = await asyncio.to_thread(
                self.client.operations.get,
                name=gemini_operation_name
            )

            if not operation.done:
                return {
                    "gemini_operation_name": gemini_operation_name,
                    "done": False,
                    "status": "pending"
                }

            if operation.error:
                return {
                    "gemini_operation_name": gemini_operation_name,
                    "done": True,
                    "status": "failed",
                    "error": operation.error.message or "Unknown error"
                }

            # Extract video URI
            generated_videos = getattr(operation.response, 'generated_videos', None)
            if not generated_videos:
                return {
                    "gemini_operation_name": gemini_operation_name,
                    "done": True,
                    "status": "failed",
                    "error": "No video URI returned. May be due to safety filters."
                }

            video_uri = generated_videos[0].video.uri if generated_videos[0].video else None
            if not video_uri:
                return {
                    "gemini_operation_name": gemini_operation_name,
                    "done": True,
                    "status": "failed",
                    "error": "No video URI in response."
                }

            # Return success with video URL (internal - will be downloaded by router)
            return {
                "gemini_operation_name": gemini_operation_name,
                "done": True,
                "status": "completed",
                "video_url": f"{video_uri}&key={self.api_key}"
            }

        except Exception as error:
            msg = str(error)
            if "404" in msg or "NOT_FOUND" in msg or "Requested entity was not found" in msg:
                return {
                    "gemini_operation_name": gemini_operation_name,
                    "done": True,
                    "status": "failed",
                    "error": "Operation not found (404). It may have expired."
                }
            raise error

    async def download_video(self, video_url: str, operation_id: str) -> str:
        """
        Download video from Gemini to local storage.
        Returns the local filename.
        """
        filename = f"{operation_id}.mp4"
        filepath = os.path.join(DOWNLOADS_DIR, filename)

        try:
            async with httpx.AsyncClient(timeout=120.0) as client:
                response = await client.get(video_url)
                response.raise_for_status()

                with open(filepath, 'wb') as f:
                    f.write(response.content)

            logger.info(f"Downloaded video to {filepath}")
            return filename
        except Exception as e:
            logger.error(f"Failed to download video: {e}")
            raise ValueError(f"Failed to download video: {e}")

    async def generate_text(
        self,
        prompt: str,
        model: Optional[str] = None
    ) -> str:
        """
        Simple text generation with Gemini.
        """
        model_name = model or MODELS["text_generation"]
        async with get_text_semaphore():
            try:
                response = await asyncio.to_thread(
                    self.client.models.generate_content,
                    model=model_name,
                    contents=types.Content(
                        parts=[types.Part.from_text(text=prompt)]
                    )
                )
                return response.text or ""
            except Exception as error:
                self._handle_api_error(error, model_name)

    async def analyze_image(
        self,
        base64_image: str,
        mime_type: str,
        prompt: str
    ) -> str:
        """
        Analyze image with custom prompt.
        """
        async with get_text_semaphore():
            try:
                response = await asyncio.to_thread(
                    self.client.models.generate_content,
                    model=MODELS["text_generation"],
                    contents=types.Content(
                        parts=[
                            types.Part.from_bytes(
                                data=base64_image,
                                mime_type=mime_type
                            ),
                            types.Part.from_text(text=prompt)
                        ]
                    )
                )
                return response.text or ""
            except Exception as error:
                self._handle_api_error(error, MODELS["text_generation"])
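For local testing it can help to exercise the service directly, outside the job queue. A minimal sketch, assuming GEMINI_API_KEY is set in the environment; the prompt text is illustrative and not part of the commit.

# Minimal sketch: call GeminiService directly (outside the job queue).
# Assumes GEMINI_API_KEY is set; the prompt is illustrative.
import asyncio

from services.gemini_service import GeminiService


async def main() -> None:
    service = GeminiService()  # reads GEMINI_API_KEY from the environment
    text = await service.generate_text("Suggest a cinematic camera move for a sunset photo.")
    print(text)


asyncio.run(main())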

services/google_auth_service.py
ADDED
@@ -0,0 +1,232 @@
"""
Modular Google OAuth Service

A self-contained, plug-and-play service for verifying Google ID tokens.
Can be used in any Python application with minimal configuration.

Usage:
    from services.google_auth_service import GoogleAuthService, GoogleUserInfo

    # Initialize with client ID
    auth_service = GoogleAuthService(client_id="your-google-client-id")

    # Or use environment variable GOOGLE_CLIENT_ID
    auth_service = GoogleAuthService()

    # Verify a Google ID token
    user_info = auth_service.verify_token(id_token)
    print(user_info.email, user_info.google_id, user_info.name)

Environment Variables:
    GOOGLE_CLIENT_ID: Your Google OAuth 2.0 Client ID

Dependencies:
    google-auth>=2.0.0
    google-auth-oauthlib>=1.0.0
"""

import os
import logging
from dataclasses import dataclass
from typing import Optional
from google.oauth2 import id_token as google_id_token
from google.auth.transport import requests as google_requests

logger = logging.getLogger(__name__)


@dataclass
class GoogleUserInfo:
    """
    User information extracted from a verified Google ID token.

    Attributes:
        google_id: Unique Google user identifier (sub claim)
        email: User's email address
        email_verified: Whether Google has verified the email
        name: User's display name (may be None)
        picture: URL to user's profile picture (may be None)
        given_name: User's first name (may be None)
        family_name: User's last name (may be None)
        locale: User's locale preference (may be None)
    """
    google_id: str
    email: str
    email_verified: bool = True
    name: Optional[str] = None
    picture: Optional[str] = None
    given_name: Optional[str] = None
    family_name: Optional[str] = None
    locale: Optional[str] = None


class GoogleAuthError(Exception):
    """Base exception for Google Auth errors."""
    pass


class InvalidTokenError(GoogleAuthError):
    """Raised when the token is invalid or expired."""
    pass


class ConfigurationError(GoogleAuthError):
    """Raised when the service is not properly configured."""
    pass


class GoogleAuthService:
    """
    Service for verifying Google OAuth ID tokens.

    This service validates ID tokens issued by Google Sign-In and extracts
    user information. It's designed to be modular and reusable across
    different applications.

    Example:
        service = GoogleAuthService()
        try:
            user_info = service.verify_token(token_from_frontend)
            print(f"Welcome {user_info.name}!")
        except InvalidTokenError:
            print("Invalid or expired token")
    """

    def __init__(
        self,
        client_id: Optional[str] = None,
        clock_skew_seconds: int = 0
    ):
        """
        Initialize the Google Auth Service.

        Args:
            client_id: Google OAuth 2.0 Client ID. If not provided,
                falls back to GOOGLE_CLIENT_ID environment variable.
            clock_skew_seconds: Allowed clock skew in seconds for token
                validation (default: 0).

        Raises:
            ConfigurationError: If no client_id is provided or found.
        """
        self.client_id = client_id or os.getenv("AUTH_SIGN_IN_GOOGLE_CLIENT_ID")
        self.clock_skew_seconds = clock_skew_seconds

        if not self.client_id:
            raise ConfigurationError(
                "Google Client ID is required. Either pass client_id parameter "
                "or set GOOGLE_CLIENT_ID environment variable."
            )

        logger.info(f"GoogleAuthService initialized with client_id: {self.client_id[:20]}...")

    def verify_token(self, id_token: str) -> GoogleUserInfo:
        """
        Verify a Google ID token and extract user information.

        Args:
            id_token: The ID token received from the frontend after
                Google Sign-In.

        Returns:
            GoogleUserInfo: Dataclass containing user's Google profile info.

        Raises:
            InvalidTokenError: If the token is invalid, expired, or
                doesn't match the expected client ID.
        """
        if not id_token:
            raise InvalidTokenError("Token cannot be empty")

        try:
            # Verify the token with Google
            idinfo = google_id_token.verify_oauth2_token(
                id_token,
                google_requests.Request(),
                self.client_id,
                clock_skew_in_seconds=self.clock_skew_seconds
            )

            # Validate issuer
            if idinfo.get("iss") not in ["accounts.google.com", "https://accounts.google.com"]:
                raise InvalidTokenError("Invalid token issuer")

            # Validate audience
            if idinfo.get("aud") != self.client_id:
                raise InvalidTokenError("Token was not issued for this application")

            # Extract user info
            return GoogleUserInfo(
                google_id=idinfo["sub"],
                email=idinfo["email"],
                email_verified=idinfo.get("email_verified", False),
                name=idinfo.get("name"),
                picture=idinfo.get("picture"),
                given_name=idinfo.get("given_name"),
                family_name=idinfo.get("family_name"),
                locale=idinfo.get("locale")
            )

        except ValueError as e:
            logger.warning(f"Token verification failed: {e}")
            raise InvalidTokenError(f"Token verification failed: {str(e)}")
        except Exception as e:
            logger.error(f"Unexpected error during token verification: {e}")
            raise InvalidTokenError(f"Token verification error: {str(e)}")

    def verify_token_safe(self, id_token: str) -> Optional[GoogleUserInfo]:
        """
        Verify a Google ID token without raising exceptions.

        Useful for cases where you want to check validity without
        exception handling.

        Args:
            id_token: The ID token to verify.

        Returns:
            GoogleUserInfo if valid, None if invalid.
        """
        try:
            return self.verify_token(id_token)
        except GoogleAuthError:
            return None


# Singleton instance for convenience (initialized on first use)
_default_service: Optional[GoogleAuthService] = None


def get_google_auth_service() -> GoogleAuthService:
    """
    Get the default GoogleAuthService instance.

    Creates a singleton instance using environment variables.

    Returns:
        GoogleAuthService: The default service instance.

    Raises:
        ConfigurationError: If GOOGLE_CLIENT_ID is not set.
    """
    global _default_service
    if _default_service is None:
        _default_service = GoogleAuthService()
    return _default_service


def verify_google_token(id_token: str) -> GoogleUserInfo:
    """
    Convenience function to verify a token using the default service.

    Args:
        id_token: The Google ID token to verify.

    Returns:
        GoogleUserInfo: Verified user information.

    Raises:
        InvalidTokenError: If verification fails.
        ConfigurationError: If service is not configured.
    """
    return get_google_auth_service().verify_token(id_token)
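This service is consumed elsewhere in the project for request authentication. As a rough illustration of how it can be plugged into a route, here is a sketch; the dependency name, header name, and error text are assumptions of this example, not the project's actual wiring.

# Sketch only: one way to expose GoogleAuthService as a FastAPI dependency.
# The function name and header name here are assumptions, not the repo's wiring.
from fastapi import Header, HTTPException, status

from services.google_auth_service import (
    GoogleUserInfo,
    InvalidTokenError,
    get_google_auth_service,
)


async def google_user(x_google_id_token: str = Header(...)) -> GoogleUserInfo:
    """Resolve the caller from a Google ID token passed in a custom header."""
    try:
        return get_google_auth_service().verify_token(x_google_id_token)
    except InvalidTokenError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc))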

services/job_worker.py
ADDED
@@ -0,0 +1,278 @@
"""
Background worker for processing Gemini jobs.
Runs continuously, picks up queued jobs, processes them.
"""
import asyncio
import logging
import os
from datetime import datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession

from core.database import DATABASE_URL
from core.models import GeminiJob
from services.gemini_service import GeminiService, DOWNLOADS_DIR

logger = logging.getLogger(__name__)

# Worker configuration
WORKER_POLL_INTERVAL = int(os.getenv("WORKER_POLL_INTERVAL", "5"))  # seconds
MAX_CONCURRENT_VIDEO_JOBS = int(os.getenv("MAX_CONCURRENT_VIDEO_JOBS", "2"))
MAX_CONCURRENT_IMAGE_JOBS = int(os.getenv("MAX_CONCURRENT_IMAGE_JOBS", "3"))
MAX_CONCURRENT_TEXT_JOBS = int(os.getenv("MAX_CONCURRENT_TEXT_JOBS", "5"))

# Track running jobs
_running_jobs = {
    "video": 0,
    "image": 0,
    "text": 0,
    "analyze": 0,
    "animation_prompt": 0
}
_lock = asyncio.Lock()


class JobWorker:
    """Background worker for processing Gemini jobs."""

    def __init__(self):
        self.engine = create_async_engine(DATABASE_URL, echo=False)
        self.async_session = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )
        self._running = False
        self._tasks = []

    async def start(self):
        """Start the worker."""
        self._running = True
        logger.info("Job worker started")
        asyncio.create_task(self._poll_loop())

    async def stop(self):
        """Stop the worker."""
        self._running = False
        # Wait for running tasks to complete
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)
        logger.info("Job worker stopped")

    async def _poll_loop(self):
        """Main polling loop."""
        while self._running:
            try:
                await self._process_queued_jobs()
            except Exception as e:
                logger.error(f"Error in worker poll loop: {e}")
            await asyncio.sleep(WORKER_POLL_INTERVAL)

    async def _process_queued_jobs(self):
        """Find and process queued jobs."""
        async with self.async_session() as session:
            # Get queued jobs
            query = select(GeminiJob).where(
                GeminiJob.status == "queued"
            ).order_by(GeminiJob.created_at)

            result = await session.execute(query)
            jobs = result.scalars().all()

            for job in jobs:
                # Check if we can process this job type
                if not await self._can_process(job.job_type):
                    continue

                # Mark as processing
                async with _lock:
                    _running_jobs[job.job_type] = _running_jobs.get(job.job_type, 0) + 1

                try:
                    job.status = "processing"
                    job.started_at = datetime.utcnow()
                    await session.commit()

                    # Process job in background
                    task = asyncio.create_task(self._process_job(job.job_id))
                    self._tasks.append(task)
                    task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
                except Exception as e:
                    logger.error(f"Error starting job {job.job_id}: {e}")
                    async with _lock:
                        _running_jobs[job.job_type] = max(0, _running_jobs.get(job.job_type, 0) - 1)

    async def _can_process(self, job_type: str) -> bool:
        """Check if we can process another job of this type."""
        async with _lock:
            current = _running_jobs.get(job_type, 0)
            if job_type == "video":
                return current < MAX_CONCURRENT_VIDEO_JOBS
            elif job_type in ("image", "edit_image"):
                return current < MAX_CONCURRENT_IMAGE_JOBS
            else:  # text, analyze, animation_prompt
                return current < MAX_CONCURRENT_TEXT_JOBS

    async def _process_job(self, job_id: str):
        """Process a single job."""
        async with self.async_session() as session:
            try:
                # Get the job
                query = select(GeminiJob).where(GeminiJob.job_id == job_id)
                result = await session.execute(query)
                job = result.scalar_one_or_none()

                if not job:
                    logger.error(f"Job {job_id} not found")
                    return

                logger.info(f"Processing job {job_id} (type: {job.job_type})")

                service = GeminiService()
                input_data = job.input_data or {}

                if job.job_type == "video":
                    await self._process_video_job(session, job, service, input_data)
                elif job.job_type == "image":
                    await self._process_image_job(session, job, service, input_data)
                elif job.job_type == "text":
                    await self._process_text_job(session, job, service, input_data)
                elif job.job_type == "analyze":
                    await self._process_analyze_job(session, job, service, input_data)
                elif job.job_type == "animation_prompt":
                    await self._process_animation_prompt_job(session, job, service, input_data)
                else:
                    job.status = "failed"
                    job.error_message = f"Unknown job type: {job.job_type}"
                    job.completed_at = datetime.utcnow()
                    await session.commit()

            except Exception as e:
                logger.error(f"Error processing job {job_id}: {e}")
                try:
                    job.status = "failed"
                    job.error_message = str(e)
                    job.completed_at = datetime.utcnow()
                    await session.commit()
                except:
                    pass
            finally:
                async with _lock:
                    job_type = job.job_type if job else "unknown"
                    _running_jobs[job_type] = max(0, _running_jobs.get(job_type, 0) - 1)

    async def _process_video_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
        """Process a video generation job."""
        # Start video generation
        result = await service.start_video_generation(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            prompt=input_data.get("prompt", ""),
            aspect_ratio=input_data.get("aspect_ratio", "16:9"),
            resolution=input_data.get("resolution", "720p"),
            number_of_videos=input_data.get("number_of_videos", 1)
        )

        # Save third party ID
        job.third_party_id = result.get("gemini_operation_name")
        await session.commit()

        # Poll until done
        while True:
            status_result = await service.check_video_status(job.third_party_id)

            if status_result.get("done"):
                if status_result.get("status") == "completed":
                    # Download video
                    video_url = status_result.get("video_url")
                    if video_url:
                        filename = await service.download_video(video_url, job.job_id)
                        job.status = "completed"
                        job.output_data = {"filename": filename}
                    else:
                        job.status = "failed"
                        job.error_message = "No video URL returned"
                else:
                    job.status = "failed"
                    job.error_message = status_result.get("error", "Unknown error")

                job.completed_at = datetime.utcnow()
                await session.commit()
                break

            await asyncio.sleep(10)  # Poll every 10 seconds

    async def _process_image_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
        """Process an image edit job."""
        result = await service.edit_image(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            prompt=input_data.get("prompt", "")
        )

        job.status = "completed"
        job.output_data = {"image": result}
        job.completed_at = datetime.utcnow()
        await session.commit()

    async def _process_text_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
        """Process a text generation job."""
        result = await service.generate_text(
            prompt=input_data.get("prompt", ""),
            model=input_data.get("model")
        )

        job.status = "completed"
        job.output_data = {"text": result}
        job.completed_at = datetime.utcnow()
        await session.commit()

    async def _process_analyze_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
        """Process an image analysis job."""
        result = await service.analyze_image(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            prompt=input_data.get("prompt", "")
        )

        job.status = "completed"
        job.output_data = {"analysis": result}
        job.completed_at = datetime.utcnow()
        await session.commit()

    async def _process_animation_prompt_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
        """Process an animation prompt generation job."""
        result = await service.generate_animation_prompt(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            custom_prompt=input_data.get("custom_prompt")
        )

        job.status = "completed"
        job.output_data = {"prompt": result}
        job.completed_at = datetime.utcnow()
        await session.commit()


# Singleton worker instance
_worker: JobWorker = None


def get_worker() -> JobWorker:
    """Get the global worker instance."""
    global _worker
    if _worker is None:
        _worker = JobWorker()
    return _worker


async def start_worker():
    """Start the background worker."""
    worker = get_worker()
    await worker.start()


async def stop_worker():
    """Stop the background worker."""
    worker = get_worker()
    await worker.stop()
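The worker is normally driven by the application's startup and shutdown hooks; for local debugging it can also be run on its own. A minimal sketch, assuming the database behind core.database.DATABASE_URL is reachable; the fixed 60-second run time is arbitrary and only illustrative.

# Sketch: run the job worker by itself for local debugging (not part of the repo).
# Assumes the database configured in core.database is reachable.
import asyncio

from services.job_worker import start_worker, stop_worker


async def main() -> None:
    await start_worker()      # kicks off the background polling loop
    await asyncio.sleep(60)   # let it drain the queue for a while (illustrative)
    await stop_worker()       # waits for in-flight jobs to finish


asyncio.run(main())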

services/jwt_service.py
ADDED
@@ -0,0 +1,378 @@
"""
Modular JWT Service

A self-contained, plug-and-play service for creating and verifying JWT tokens.
Can be used in any Python application with minimal configuration.

Usage:
    from services.jwt_service import JWTService, TokenPayload

    # Initialize with secret key
    jwt_service = JWTService(secret_key="your-secret-key")

    # Or use environment variable JWT_SECRET
    jwt_service = JWTService()

    # Create a token
    token = jwt_service.create_token(user_id="user123", email="[email protected]")

    # Verify a token
    payload = jwt_service.verify_token(token)
    print(payload.user_id, payload.email)

Environment Variables:
    JWT_SECRET: Your secret key for signing tokens (required)
    JWT_EXPIRY_HOURS: Token expiry in hours (default: 168 = 7 days)
    JWT_ALGORITHM: Algorithm to use (default: HS256)

Dependencies:
    PyJWT>=2.8.0

Generate a secure secret:
    python -c "import secrets; print(secrets.token_urlsafe(64))"
"""

import os
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
import jwt

logger = logging.getLogger(__name__)


@dataclass
class TokenPayload:
    """
    Payload extracted from a verified JWT token.

    Attributes:
        user_id: The user's unique identifier (sub claim)
        email: The user's email address
        issued_at: When the token was issued
        expires_at: When the token expires
        extra: Any additional claims in the token
    """
    user_id: str
    email: str
    issued_at: datetime
    expires_at: datetime
    extra: Dict[str, Any] = None

    def __post_init__(self):
        if self.extra is None:
            self.extra = {}

    @property
    def is_expired(self) -> bool:
        """Check if the token has expired."""
        return datetime.utcnow() > self.expires_at

    @property
    def time_until_expiry(self) -> timedelta:
        """Get time remaining until expiry."""
        return self.expires_at - datetime.utcnow()


class JWTError(Exception):
    """Base exception for JWT errors."""
    pass


class TokenExpiredError(JWTError):
    """Raised when the token has expired."""
    pass


class InvalidTokenError(JWTError):
    """Raised when the token is invalid."""
    pass


class ConfigurationError(JWTError):
    """Raised when the service is not properly configured."""
    pass


class JWTService:
    """
    Service for creating and verifying JWT tokens.

    This service handles JWT token lifecycle for authentication.
    It's designed to be modular and reusable across different applications.

    Example:
        service = JWTService(secret_key="my-secret")

        # Create token
        token = service.create_token(user_id="u123", email="[email protected]")

        # Verify token
        try:
            payload = service.verify_token(token)
            print(f"User: {payload.user_id}")
        except TokenExpiredError:
            print("Token expired, please login again")
        except InvalidTokenError:
            print("Invalid token")
    """

    # Default configuration
    DEFAULT_ALGORITHM = "HS256"
    DEFAULT_EXPIRY_HOURS = 168  # 7 days

    def __init__(
        self,
        secret_key: Optional[str] = None,
        algorithm: Optional[str] = None,
        expiry_hours: Optional[int] = None
    ):
        """
        Initialize the JWT Service.

        Args:
            secret_key: Secret key for signing tokens. If not provided,
                falls back to JWT_SECRET environment variable.
            algorithm: JWT algorithm (default: HS256).
            expiry_hours: Token expiry in hours (default: 168 = 7 days).

        Raises:
            ConfigurationError: If no secret_key is provided or found.
        """
        self.secret_key = secret_key or os.getenv("JWT_SECRET")
        self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
        self.expiry_hours = expiry_hours or int(
            os.getenv("JWT_EXPIRY_HOURS", str(self.DEFAULT_EXPIRY_HOURS))
        )

        if not self.secret_key:
            raise ConfigurationError(
                "JWT secret key is required. Either pass secret_key parameter "
                "or set JWT_SECRET environment variable. "
                "Generate one with: python -c \"import secrets; print(secrets.token_urlsafe(64))\""
            )

        # Warn if secret is too short
        if len(self.secret_key) < 32:
            logger.warning(
                "JWT secret key is short (< 32 chars). "
                "Consider using a longer secret for better security."
            )

        logger.info(
            f"JWTService initialized (algorithm={self.algorithm}, "
            f"expiry={self.expiry_hours}h)"
        )

    def create_token(
        self,
        user_id: str,
        email: str,
        extra_claims: Optional[Dict[str, Any]] = None,
        expiry_hours: Optional[int] = None
    ) -> str:
        """
        Create a JWT token for a user.

        Args:
            user_id: The user's unique identifier.
            email: The user's email address.
            extra_claims: Additional claims to include in the token.
            expiry_hours: Custom expiry for this token (overrides default).

        Returns:
            str: The encoded JWT token.
        """
        now = datetime.utcnow()
        expiry = expiry_hours or self.expiry_hours

        payload = {
            "sub": user_id,
            "email": email,
            "iat": now,
            "exp": now + timedelta(hours=expiry),
        }

        if extra_claims:
            payload.update(extra_claims)

        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

        logger.debug(f"Created token for user_id={user_id}")
        return token

    def verify_token(self, token: str) -> TokenPayload:
        """
        Verify a JWT token and extract the payload.

        Args:
            token: The JWT token to verify.

        Returns:
            TokenPayload: Dataclass containing the verified payload.

        Raises:
            TokenExpiredError: If the token has expired.
            InvalidTokenError: If the token is invalid or malformed.
        """
        if not token:
            raise InvalidTokenError("Token cannot be empty")

        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm]
            )

            # Extract standard claims
            user_id = payload.get("sub")
            email = payload.get("email")
            iat = payload.get("iat")
            exp = payload.get("exp")

            if not user_id or not email:
                raise InvalidTokenError("Token missing required claims (sub, email)")

            # Convert timestamps to datetime
            issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
            expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp

            # Extract extra claims
            standard_claims = {"sub", "email", "iat", "exp"}
+
extra = {k: v for k, v in payload.items() if k not in standard_claims}
|
| 245 |
+
|
| 246 |
+
return TokenPayload(
|
| 247 |
+
user_id=user_id,
|
| 248 |
+
email=email,
|
| 249 |
+
issued_at=issued_at,
|
| 250 |
+
expires_at=expires_at,
|
| 251 |
+
extra=extra
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
except jwt.ExpiredSignatureError:
|
| 255 |
+
logger.debug("Token verification failed: expired")
|
| 256 |
+
raise TokenExpiredError("Token has expired")
|
| 257 |
+
except jwt.InvalidTokenError as e:
|
| 258 |
+
logger.debug(f"Token verification failed: {e}")
|
| 259 |
+
raise InvalidTokenError(f"Invalid token: {str(e)}")
|
| 260 |
+
except Exception as e:
|
| 261 |
+
logger.error(f"Unexpected error during token verification: {e}")
|
| 262 |
+
raise InvalidTokenError(f"Token verification error: {str(e)}")
|
| 263 |
+
|
| 264 |
+
def verify_token_safe(self, token: str) -> Optional[TokenPayload]:
|
| 265 |
+
"""
|
| 266 |
+
Verify a JWT token without raising exceptions.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
token: The JWT token to verify.
|
| 270 |
+
|
| 271 |
+
Returns:
|
| 272 |
+
TokenPayload if valid, None if invalid or expired.
|
| 273 |
+
"""
|
| 274 |
+
try:
|
| 275 |
+
return self.verify_token(token)
|
| 276 |
+
except JWTError:
|
| 277 |
+
return None
|
| 278 |
+
|
| 279 |
+
def refresh_token(
|
| 280 |
+
self,
|
| 281 |
+
token: str,
|
| 282 |
+
expiry_hours: Optional[int] = None
|
| 283 |
+
) -> str:
|
| 284 |
+
"""
|
| 285 |
+
Refresh a token by creating a new one with the same claims.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
token: The current (possibly expired) token.
|
| 289 |
+
expiry_hours: Custom expiry for the new token.
|
| 290 |
+
|
| 291 |
+
Returns:
|
| 292 |
+
str: A new JWT token with updated expiry.
|
| 293 |
+
|
| 294 |
+
Raises:
|
| 295 |
+
InvalidTokenError: If the token is malformed.
|
| 296 |
+
"""
|
| 297 |
+
try:
|
| 298 |
+
# Decode without verifying expiry
|
| 299 |
+
payload = jwt.decode(
|
| 300 |
+
token,
|
| 301 |
+
self.secret_key,
|
| 302 |
+
algorithms=[self.algorithm],
|
| 303 |
+
options={"verify_exp": False}
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
user_id = payload.get("sub")
|
| 307 |
+
email = payload.get("email")
|
| 308 |
+
|
| 309 |
+
if not user_id or not email:
|
| 310 |
+
raise InvalidTokenError("Token missing required claims")
|
| 311 |
+
|
| 312 |
+
# Preserve extra claims
|
| 313 |
+
standard_claims = {"sub", "email", "iat", "exp"}
|
| 314 |
+
extra = {k: v for k, v in payload.items() if k not in standard_claims}
|
| 315 |
+
|
| 316 |
+
return self.create_token(
|
| 317 |
+
user_id=user_id,
|
| 318 |
+
email=email,
|
| 319 |
+
extra_claims=extra,
|
| 320 |
+
expiry_hours=expiry_hours
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
except jwt.InvalidTokenError as e:
|
| 324 |
+
raise InvalidTokenError(f"Cannot refresh invalid token: {str(e)}")
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
# Singleton instance for convenience
|
| 328 |
+
_default_service: Optional[JWTService] = None
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def get_jwt_service() -> JWTService:
|
| 332 |
+
"""
|
| 333 |
+
Get the default JWTService instance.
|
| 334 |
+
|
| 335 |
+
Creates a singleton instance using environment variables.
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
JWTService: The default service instance.
|
| 339 |
+
|
| 340 |
+
Raises:
|
| 341 |
+
ConfigurationError: If JWT_SECRET is not set.
|
| 342 |
+
"""
|
| 343 |
+
global _default_service
|
| 344 |
+
if _default_service is None:
|
| 345 |
+
_default_service = JWTService()
|
| 346 |
+
return _default_service
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def create_access_token(user_id: str, email: str, **kwargs) -> str:
|
| 350 |
+
"""
|
| 351 |
+
Convenience function to create a token using the default service.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
user_id: The user's unique identifier.
|
| 355 |
+
email: The user's email address.
|
| 356 |
+
**kwargs: Additional arguments passed to create_token.
|
| 357 |
+
|
| 358 |
+
Returns:
|
| 359 |
+
str: The encoded JWT token.
|
| 360 |
+
"""
|
| 361 |
+
return get_jwt_service().create_token(user_id, email, **kwargs)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def verify_access_token(token: str) -> TokenPayload:
|
| 365 |
+
"""
|
| 366 |
+
Convenience function to verify a token using the default service.
|
| 367 |
+
|
| 368 |
+
Args:
|
| 369 |
+
token: The JWT token to verify.
|
| 370 |
+
|
| 371 |
+
Returns:
|
| 372 |
+
TokenPayload: Verified token payload.
|
| 373 |
+
|
| 374 |
+
Raises:
|
| 375 |
+
TokenExpiredError: If the token has expired.
|
| 376 |
+
InvalidTokenError: If the token is invalid.
|
| 377 |
+
"""
|
| 378 |
+
return get_jwt_service().verify_token(token)
|
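For orientation, a minimal usage sketch of the module-level helpers defined above (the secret value and the extra "plan" claim are placeholders invented for this sketch, not values used by the service):

    import os
    os.environ.setdefault("JWT_SECRET", "a-long-random-secret-generated-offline")  # placeholder secret

    from services.jwt_service import (
        create_access_token, verify_access_token,
        TokenExpiredError, InvalidTokenError,
    )

    # Issue a token via the default singleton service
    token = create_access_token("u123", "user@example.com", extra_claims={"plan": "free"})

    # Verify it and read the dataclass payload
    try:
        payload = verify_access_token(token)
        print(payload.user_id, payload.extra, payload.time_until_expiry)
    except TokenExpiredError:
        print("Token expired, please login again")
    except InvalidTokenError:
        print("Invalid token")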
tests/conftest.py
CHANGED
@@ -1,14 +1,27 @@
 import pytest
 import os
 import sys
+from unittest.mock import patch, MagicMock
 from fastapi.testclient import TestClient
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
 
+# Set test environment variables BEFORE importing app
+os.environ["JWT_SECRET"] = "test-secret-key-that-is-long-enough-for-security-purposes"
+os.environ["GOOGLE_CLIENT_ID"] = "test-google-client-id.apps.googleusercontent.com"
+os.environ["RESET_DB"] = "true"  # Prevent Drive download during tests
+
 # Add parent directory to path to allow importing app
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 
-
-
+# Mock the drive service before importing app
+with patch("services.drive_service.DriveService") as mock_drive:
+    mock_instance = MagicMock()
+    mock_instance.download_db.return_value = False
+    mock_instance.upload_db.return_value = True
+    mock_drive.return_value = mock_instance
+
+    from app import app
+    from core.database import get_db, Base
 
 # Use a file-based SQLite database for testing to ensure persistence
 TEST_DATABASE_URL = "sqlite+aiosqlite:///./test_blink_data.db"
@@ -34,16 +47,6 @@ async def db_session(test_engine):
 
     async with async_session() as session:
         yield session
-
-    # We don't drop tables here to allow persistence if needed,
-    # but for isolation we usually want to.
-    # However, the previous test relied on persistence for rate limiting.
-    # Let's keep it simple: we clear data manually if needed or rely on fresh DB per run (session scope engine).
-    # Actually, for rate limiting test to work across requests, we need persistence.
-    # But for isolation between tests, we want cleanup.
-    # The previous test_app.py had cleanup_db fixture. Let's replicate that logic in the test file or here.
-
-    # Let's just yield session here.
 
 @pytest.fixture(scope="function")
 def client(test_engine):
@@ -57,6 +60,12 @@ def client(test_engine):
         yield session
 
     app.dependency_overrides[get_db] = override_get_db
-
-
+
+    # Mock drive service for the test client
+    with patch("routers.auth.drive_service") as mock_auth_drive:
+        mock_auth_drive.upload_db.return_value = True
+        with TestClient(app) as c:
+            yield c
+
     app.dependency_overrides.clear()
+
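As a sketch of how these fixtures are meant to be consumed, a hypothetical test file under tests/ could look like this (the endpoint and the expected 401 mirror the integration tests further down):

    # tests/test_example.py (hypothetical file name)

    def test_me_requires_auth(client):
        # `client` is the fixture above: a TestClient with get_db overridden
        # and the Drive service mocked, so no external calls are made.
        response = client.get("/auth/me")
        assert response.status_code == 401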
tests/debug_gemini_service.py
ADDED
@@ -0,0 +1,165 @@
+"""
+Debug script to test Gemini service with API keys from environment.
+Keys should be in GEMINI_KEYS environment variable, comma-separated.
+
+Usage:
+    GEMINI_KEYS="key1,key2,key3" python tests/debug_gemini_service.py
+"""
+import os
+import sys
+import asyncio
+import logging
+import base64
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+# Add parent directory to path
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from services.gemini_service import GeminiService, MODELS
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Test image path
+TEST_IMAGE_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test.jpg")
+
+
+def load_test_image():
+    """Load test image and return base64 + mime type."""
+    if not os.path.exists(TEST_IMAGE_PATH):
+        logger.error(f"Test image not found: {TEST_IMAGE_PATH}")
+        return None, None
+
+    with open(TEST_IMAGE_PATH, "rb") as f:
+        image_data = f.read()
+
+    base64_image = base64.b64encode(image_data).decode("utf-8")
+    mime_type = "image/jpeg"
+
+    logger.info(f"Loaded test image: {TEST_IMAGE_PATH} ({len(image_data)} bytes)")
+    return base64_image, mime_type
+
+
+async def test_generate_text(service: GeminiService, key_index: int):
+    """Test simple text generation."""
+    logger.info(f"[Key {key_index}] Testing text generation...")
+    try:
+        result = await service.generate_text("Say hello in one word.")
+        logger.info(f"[Key {key_index}] Text generation result: {result[:100]}...")
+        return True
+    except Exception as e:
+        logger.error(f"[Key {key_index}] Text generation failed: {e}")
+        return False
+
+
+async def test_analyze_image(service: GeminiService, key_index: int, base64_image: str, mime_type: str):
+    """Test image analysis."""
+    logger.info(f"[Key {key_index}] Testing image analysis...")
+    try:
+        result = await service.analyze_image(
+            base64_image=base64_image,
+            mime_type=mime_type,
+            prompt="Describe this image in one sentence."
+        )
+        logger.info(f"[Key {key_index}] Image analysis result: {result[:100]}...")
+        return True
+    except Exception as e:
+        logger.error(f"[Key {key_index}] Image analysis failed: {e}")
+        return False
+
+
+async def test_generate_animation_prompt(service: GeminiService, key_index: int, base64_image: str, mime_type: str):
+    """Test animation prompt generation."""
+    logger.info(f"[Key {key_index}] Testing animation prompt generation...")
+    try:
+        result = await service.generate_animation_prompt(
+            base64_image=base64_image,
+            mime_type=mime_type
+        )
+        logger.info(f"[Key {key_index}] Animation prompt result: {result[:100]}...")
+        return True
+    except Exception as e:
+        logger.error(f"[Key {key_index}] Animation prompt generation failed: {e}")
+        return False
+
+
+async def test_key(api_key: str, key_index: int, base64_image: str, mime_type: str):
+    """Test all basic operations with a single API key."""
+    logger.info(f"\n{'='*50}")
+    logger.info(f"Testing Key {key_index}: {api_key[:10]}...{api_key[-4:]}")
+    logger.info(f"{'='*50}")
+
+    try:
+        service = GeminiService(api_key)
+    except Exception as e:
+        logger.error(f"[Key {key_index}] Failed to initialize service: {e}")
+        return {"key_index": key_index, "valid": False, "error": str(e)}
+
+    results = {
+        "key_index": key_index,
+        "key_preview": f"{api_key[:10]}...{api_key[-4:]}",
+        "text_generation": await test_generate_text(service, key_index),
+        "image_analysis": await test_analyze_image(service, key_index, base64_image, mime_type),
+        "animation_prompt": await test_generate_animation_prompt(service, key_index, base64_image, mime_type),
+    }
+
+    results["valid"] = all([
+        results["text_generation"],
+        results["image_analysis"],
+        results["animation_prompt"]
+    ])
+
+    return results
+
+
+async def main():
+    # Load test image
+    base64_image, mime_type = load_test_image()
+    if not base64_image:
+        logger.error("Cannot run tests without test image. Please add test.jpg to project root.")
+        return
+
+    gemini_keys_str = os.getenv("GEMINI_KEYS", "")
+
+    if not gemini_keys_str:
+        logger.error("GEMINI_KEYS environment variable not set.")
+        logger.info("Usage: GEMINI_KEYS='key1,key2,key3' python tests/debug_gemini_service.py")
+        return
+
+    keys = [k.strip() for k in gemini_keys_str.split(",") if k.strip()]
+
+    if not keys:
+        logger.error("No valid keys found in GEMINI_KEYS.")
+        return
+
+    logger.info(f"Found {len(keys)} API key(s) to test.")
+    logger.info(f"Available models: {MODELS}")
+
+    all_results = []
+    for i, key in enumerate(keys):
+        result = await test_key(key, i + 1, base64_image, mime_type)
+        all_results.append(result)
+
+    # Summary
+    logger.info(f"\n{'='*50}")
+    logger.info("SUMMARY")
+    logger.info(f"{'='*50}")
+
+    valid_count = sum(1 for r in all_results if r.get("valid", False))
+    logger.info(f"Valid keys: {valid_count}/{len(keys)}")
+
+    for result in all_results:
+        status = "✓ VALID" if result.get("valid") else "✗ INVALID"
+        logger.info(f"  Key {result['key_index']}: {status}")
+        if not result.get("valid"):
+            for test_name in ["text_generation", "image_analysis", "animation_prompt"]:
+                if test_name in result and not result[test_name]:
+                    logger.info(f"    - {test_name}: FAILED")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
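To exercise a single key programmatically instead of looping over GEMINI_KEYS, a rough sketch reusing the helpers above (the GEMINI_API_KEY variable name is an assumption for this sketch, and it presumes the script is importable as tests.debug_gemini_service with test.jpg present):

    import asyncio
    import os

    from tests.debug_gemini_service import load_test_image, test_key

    async def check_one_key():
        # Load the same test.jpg described in the script above
        base64_image, mime_type = load_test_image()
        result = await test_key(os.environ["GEMINI_API_KEY"], 1, base64_image, mime_type)
        print(result)

    asyncio.run(check_one_key())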
tests/test_integration.py
CHANGED
@@ -1,25 +1,28 @@
+"""
+Integration Tests for Google OAuth Authentication
+
+Tests the new Google Sign-In flow, JWT token handling, and API access.
+"""
 import pytest
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
 import os
 from sqlalchemy import text
 
+from services.google_auth_service import GoogleUserInfo
+from services.jwt_service import JWTService
+
+
 # Cleanup fixture
 @pytest.fixture(autouse=True)
 def cleanup_db():
     if os.path.exists("./test_blink_data.db"):
-        # We can't easily delete the file if it's open by the engine in conftest.
-        # Instead, we should probably truncate tables.
         pass
     yield
-    # Cleanup logic if needed
 
-# We need a way to clear data between tests if we want isolation.
-# Since we are using a file-based DB shared across the session (engine),
-# we should truncate tables.
 
 @pytest.fixture(autouse=True)
 async def clear_tables(db_session):
-
+    """Truncate all tables between tests."""
     async with db_session.begin():
         await db_session.execute(text("DELETE FROM users"))
         await db_session.execute(text("DELETE FROM rate_limits"))
@@ -27,65 +30,179 @@ async def clear_tables(db_session):
         await db_session.execute(text("DELETE FROM blink_data"))
         await db_session.commit()
 
-            "user_id": "test-user-2",
-            "email": "[email protected]"
+
+@pytest.fixture
+def jwt_service():
+    """Create a JWT service for testing."""
+    return JWTService(secret_key="test-secret-key-for-testing-only")
+
+
+@pytest.fixture
+def mock_google_user():
+    """Mock Google user info."""
+    return GoogleUserInfo(
+        google_id="google_123456789",
+        email="[email protected]",
+        email_verified=True,
+        name="Test User",
+        picture="https://example.com/photo.jpg"
+    )
+
+
+class TestGoogleAuth:
+    """Test Google OAuth authentication flow."""
 
+    @patch("routers.auth.get_google_auth_service")
+    def test_google_auth_new_user(self, mock_get_service, client, mock_google_user):
+        """Test new user registration via Google."""
+        mock_service = MagicMock()
+        mock_service.verify_token.return_value = mock_google_user
+        mock_get_service.return_value = mock_service
+
+        response = client.post("/auth/google", json={
+            "id_token": "fake-google-token-12345",
+            "temp_user_id": "temp-user-abc"
         })
 
-        # Validate
-        response = client.post("/auth/validate", headers={"X-Secret-Key": "sk_test_key_1234567890123456789012345"})
         assert response.status_code == 200
-        userid_param = user_id + encrypted_data
+        data = response.json()
+        assert data["success"] == True
+        assert data["is_new_user"] == True
+        assert data["email"] == "[email protected]"
+        assert data["name"] == "Test User"
+        assert data["credits"] == 100
+        assert "access_token" in data
+        assert data["access_token"] != ""
 
+    @patch("routers.auth.get_google_auth_service")
+    def test_google_auth_existing_user(self, mock_get_service, client, mock_google_user):
+        """Test existing user login via Google."""
+        mock_service = MagicMock()
+        mock_service.verify_token.return_value = mock_google_user
+        mock_get_service.return_value = mock_service
+
+        # First login - creates user
+        response1 = client.post("/auth/google", json={"id_token": "token1"})
+        assert response1.status_code == 200
+        assert response1.json()["is_new_user"] == True
+
+        # Second login - same user
+        response2 = client.post("/auth/google", json={"id_token": "token2"})
+        assert response2.status_code == 200
+        data = response2.json()
+        assert data["is_new_user"] == False
+        assert data["email"] == "[email protected]"
+        assert data["credits"] == 100  # Credits preserved
 
+    @patch("routers.auth.get_google_auth_service")
+    def test_google_auth_invalid_token(self, mock_get_service, client):
+        """Test handling of invalid Google token."""
+        from services.google_auth_service import InvalidTokenError
+
+        mock_service = MagicMock()
+        mock_service.verify_token.side_effect = InvalidTokenError("Invalid token")
+        mock_get_service.return_value = mock_service
+
+        response = client.post("/auth/google", json={"id_token": "invalid-token"})
+
+        assert response.status_code == 401
+        assert "Invalid Google token" in response.json()["detail"]
+
+
+class TestJWTAuth:
+    """Test JWT token authentication."""
+
+    @patch("routers.auth.get_google_auth_service")
+    def test_get_current_user(self, mock_get_service, client, mock_google_user):
+        """Test getting current user with JWT."""
+        mock_service = MagicMock()
+        mock_service.verify_token.return_value = mock_google_user
+        mock_get_service.return_value = mock_service
+
+        # Login to get token
+        login_response = client.post("/auth/google", json={"id_token": "token"})
+        token = login_response.json()["access_token"]
+
+        # Get user info
+        response = client.get("/auth/me", headers={"Authorization": f"Bearer {token}"})
+
+        assert response.status_code == 200
+        data = response.json()
+        assert data["email"] == "[email protected]"
+        assert data["credits"] == 100
+
+    def test_missing_auth_header(self, client):
+        """Test request without Authorization header."""
+        response = client.get("/auth/me")
+        assert response.status_code == 401
+        assert "Missing Authorization header" in response.json()["detail"]
+
+    def test_invalid_token_format(self, client):
+        """Test request with invalid token format."""
+        response = client.get("/auth/me", headers={"Authorization": "InvalidFormat"})
+        assert response.status_code == 401
+        assert "Invalid Authorization header format" in response.json()["detail"]
+
+    def test_invalid_token(self, client):
+        """Test request with invalid JWT token."""
+        response = client.get("/auth/me", headers={"Authorization": "Bearer invalid.jwt.token"})
+        assert response.status_code == 401
+
+
+class TestCreditSystem:
+    """Test credit deduction system."""
+
+    @patch("routers.auth.get_google_auth_service")
+    def test_credit_deduction(self, mock_get_service, client, mock_google_user):
+        """Test that credits are deducted when using API."""
+        mock_service = MagicMock()
+        mock_service.verify_token.return_value = mock_google_user
+        mock_get_service.return_value = mock_service
+
+        # Login
+        login_response = client.post("/auth/google", json={"id_token": "token"})
+        token = login_response.json()["access_token"]
+        initial_credits = login_response.json()["credits"]
+
+        # Make an API call that deducts credits (would need gemini endpoint mock)
+        # For now, just verify user info doesn't deduct credits
+        response = client.get("/auth/me", headers={"Authorization": f"Bearer {token}"})
+        assert response.json()["credits"] == initial_credits  # No deduction for info endpoint
+
+
+class TestBlinkFlow:
+    """Test blink data collection."""
+
+    def test_blink_flow(self, client):
+        """Test Blink endpoint still works."""
+        user_id = "12345678901234567890"
+        encrypted_data = "some_encrypted_data_base64"
+        userid_param = user_id + encrypted_data
+
+        response = client.get(f"/blink?userid={userid_param}")
+        assert response.status_code == 200
+        data = response.json()
+        assert data["status"] == "success"
+        assert data["user_id"] == user_id
+
+        # Verify data stored
+        response = client.get("/api/data")
         assert response.status_code == 200
+        items = response.json()["items"]
+        assert len(items) > 0
+        assert items[0]["user_id"] == user_id
+
+
+class TestRateLimiting:
+    """Test rate limiting."""
+
+    def test_rate_limiting(self, client):
+        """Test rate limiting on auth endpoints."""
+        # 10 requests should succeed
+        for _ in range(10):
+            response = client.post("/auth/check-registration", json={"user_id": "rate-limit-test"})
+            assert response.status_code == 200
 
+        # 11th request should fail
+        response = client.post("/auth/check-registration", json={"user_id": "rate-limit-test"})
+        assert response.status_code == 429
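For reference, the client-side flow these tests simulate looks roughly like this (a sketch only; the base URL is a placeholder and a real Google ID token would come from the Google Sign-In SDK, see also docs/CLIENT_INTEGRATION.md added in this commit):

    import requests

    BASE_URL = "http://localhost:7860"  # placeholder; use the deployed Space URL in practice

    # 1) Exchange a Google ID token for an app JWT
    resp = requests.post(
        f"{BASE_URL}/auth/google",
        json={"id_token": "<google-id-token>", "temp_user_id": "temp-user-abc"},
    )
    resp.raise_for_status()
    access_token = resp.json()["access_token"]

    # 2) Call authenticated endpoints with the Bearer token
    me = requests.get(f"{BASE_URL}/auth/me", headers={"Authorization": f"Bearer {access_token}"})
    print(me.json()["email"], me.json()["credits"])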