from fastapi import FastAPI, BackgroundTasks, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Tuple
import json
import logging
import os
import time
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime

# Import our modules
from scraper import fetch_hazard_tweets
from classifier import classify_tweets
from pg_db import init_db, upsert_hazardous_tweet

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Ocean Hazard Detection API",
    description="API for detecting ocean hazards from social media posts",
    version="1.0.0",
)

# CORS middleware.
# Allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable and default to "*" (any origin). Configure an explicit
# list for production: a wildcard combined with allow_credentials=True is not
# appropriate outside development.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize database. A failure here is tolerated on purpose: /analyze still
# fetches and classifies tweets, it just cannot persist them.
try:
    init_db()
    logger.info("Database initialized successfully")
except Exception as e:
    logger.warning(f"Database initialization failed: {e}. API will work without database persistence.")


# Pydantic models
class TweetAnalysisRequest(BaseModel):
    """Body of POST /analyze."""

    # Number of tweets to fetch and classify; the documented range is 1-50.
    limit: int = Field(20, ge=1, le=50)
    # Custom search query; when omitted the default hazard query is used.
    query: Optional[str] = None


class TweetAnalysisResponse(BaseModel):
    """Result of POST /analyze."""

    total_tweets: int
    hazardous_tweets: int
    results: List[dict]
    # Wall time spent handling the request, in seconds.
    processing_time: float


class HealthResponse(BaseModel):
    """Payload returned by the health-check endpoints."""

    status: str
    message: str
    timestamp: str


# Health check endpoint
@app.get("/", response_model=HealthResponse)
def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        message="Ocean Hazard Detection API is running",
        timestamp=datetime.now(timezone.utc).isoformat(),
    )


@app.get("/health", response_model=HealthResponse)
def health():
    """Alternative health check endpoint"""
    return health_check()


def _fetch_tweets(query: Optional[str], limit: int) -> List[dict]:
    """Fetch up to ``limit`` tweets, using ``query`` if given, else the default hazard query."""
    if query:
        # Imported lazily, as before, so only custom-query requests need these helpers.
        from scraper import search_tweets, extract_tweets

        return extract_tweets(search_tweets(query, limit=limit))
    return fetch_hazard_tweets(limit=limit)


def _parse_created_at(created_at) -> Tuple[str, str]:
    """Split a tweet timestamp into ``("YYYY-MM-DD", "HH:MM:SS")`` strings.

    Accepts a ``datetime``, an RFC 2822 date string (parsed with
    ``email.utils.parsedate_to_datetime``) or, for strings containing ``T``,
    an ISO 8601 string (a trailing ``Z`` is treated as UTC). Returns
    ``("", "")`` for empty or unparseable values. The date and time are those
    of the parsed value's own timezone; no conversion is applied.
    """
    if isinstance(created_at, datetime):
        dt = created_at
    elif not created_at or not isinstance(created_at, str):
        return "", ""
    else:
        dt = None
        try:
            dt = parsedate_to_datetime(created_at)
        except (TypeError, ValueError):
            # Not RFC 2822; the exception type differs between Python versions.
            dt = None
        if dt is None and "T" in created_at:
            try:
                dt = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
            except ValueError:
                dt = None
    if dt is None:
        return "", ""
    return dt.date().isoformat(), dt.time().strftime("%H:%M:%S")


def _to_db_record(result: dict) -> dict:
    """Map one classifier result to the keyword arguments of ``upsert_hazardous_tweet``."""
    ner = result.get("ner") or {}
    hazards = ner.get("hazards") or []
    locations = ner.get("locations") or []
    if not locations and result.get("location"):
        # Fall back to the tweet's own location field when NER found none.
        locations = [result["location"]]
    sentiment = result.get("sentiment") or {}
    tweet_date, tweet_time = _parse_created_at(result.get("created_at") or "")
    return {
        "tweet_url": result.get("tweet_url") or "",
        "hazard_type": ", ".join(map(str, hazards)) if hazards else "unknown",
        "location": ", ".join(map(str, locations)) if locations else "unknown",
        "sentiment_label": sentiment.get("label", "unknown"),
        # `or 0.0` guards against an explicit None score, which float() rejects.
        "sentiment_score": float(sentiment.get("score") or 0.0),
        "tweet_date": tweet_date,
        "tweet_time": tweet_time,
    }


def _store_hazardous_tweets(hazardous: List[dict]) -> None:
    """Best-effort persistence of hazardous results.

    Stops at the first error and logs a warning instead of raising, so a
    missing or broken database never fails the analysis request.
    """
    try:
        for result in hazardous:
            upsert_hazardous_tweet(**_to_db_record(result))
    except Exception as db_error:
        logger.warning(f"Database storage failed: {db_error}. Results will not be persisted.")
        return
    logger.info(f"Stored {len(hazardous)} hazardous tweets in database")


# Main analysis endpoint.
# Declared with plain `def` (not `async def`) because scraping, classification
# and DB writes are blocking calls; FastAPI runs sync handlers in a threadpool
# so they do not stall the event loop.
@app.post("/analyze", response_model=TweetAnalysisResponse)
def analyze_tweets(request: TweetAnalysisRequest):
    """
    Analyze tweets for ocean hazards

    - **limit**: Number of tweets to analyze (1-50)
    - **query**: Custom search query (optional)
    """
    start = time.perf_counter()
    try:
        logger.info(f"Starting analysis with limit: {request.limit}")

        tweets = _fetch_tweets(request.query, request.limit)
        logger.info(f"Fetched {len(tweets)} tweets")

        results = classify_tweets(tweets)
        logger.info(f"Classified {len(results)} tweets")

        # Count hazards from the classification itself so the reported number
        # is correct even when persistence fails part-way or is unavailable.
        hazardous = [r for r in results if r.get("hazardous") == 1]
        _store_hazardous_tweets(hazardous)

        return TweetAnalysisResponse(
            total_tweets=len(results),
            hazardous_tweets=len(hazardous),
            results=results,
            processing_time=time.perf_counter() - start,
        )
    except Exception as e:
        logger.exception(f"Analysis failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


# Get stored hazardous tweets (sync handler: the DB driver call is blocking)
@app.get("/hazardous-tweets")
def get_hazardous_tweets(
    limit: int = Query(100, ge=0),
    offset: int = Query(0, ge=0),
):
    """
    Get stored hazardous tweets from database

    - **limit**: Maximum number of tweets to return (default: 100)
    - **offset**: Number of tweets to skip (default: 0)
    """
    try:
        from pg_db import get_conn

        with get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT tweet_url, hazard_type, location, sentiment_label, 
                           sentiment_score, tweet_date, tweet_time, inserted_at
                    FROM hazardous_tweets 
                    ORDER BY inserted_at DESC 
                    LIMIT %s OFFSET %s
                """, (limit, offset))
                columns = [desc[0] for desc in cur.description]
                results = [dict(zip(columns, row)) for row in cur.fetchall()]
        return {
            "tweets": results,
            "count": len(results),
            "limit": limit,
            "offset": offset,
        }
    except Exception as e:
        logger.exception(f"Failed to fetch hazardous tweets: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


# Get statistics (sync handler: the DB driver calls are blocking)
@app.get("/stats")
def get_stats():
    """Get analysis statistics"""
    try:
        from pg_db import get_conn

        with get_conn() as conn:
            with conn.cursor() as cur:
                # Total hazardous tweets
                cur.execute("SELECT COUNT(*) FROM hazardous_tweets")
                total_hazardous = cur.fetchone()[0]

                # By hazard type
                cur.execute("""
                    SELECT hazard_type, COUNT(*) as count 
                    FROM hazardous_tweets 
                    GROUP BY hazard_type 
                    ORDER BY count DESC
                """)
                hazard_types = [{"type": row[0], "count": row[1]} for row in cur.fetchall()]

                # By location (top 10, excluding the "unknown" placeholder)
                cur.execute("""
                    SELECT location, COUNT(*) as count 
                    FROM hazardous_tweets 
                    WHERE location != 'unknown'
                    GROUP BY location 
                    ORDER BY count DESC 
                    LIMIT 10
                """)
                locations = [{"location": row[0], "count": row[1]} for row in cur.fetchall()]

                # By sentiment
                cur.execute("""
                    SELECT sentiment_label, COUNT(*) as count 
                    FROM hazardous_tweets 
                    GROUP BY sentiment_label 
                    ORDER BY count DESC
                """)
                sentiments = [{"sentiment": row[0], "count": row[1]} for row in cur.fetchall()]

        return {
            "total_hazardous_tweets": total_hazardous,
            "hazard_types": hazard_types,
            "top_locations": locations,
            "sentiment_distribution": sentiments,
        }
    except Exception as e:
        logger.exception(f"Failed to fetch statistics: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)