"""Minimal Flask API with token-based authentication middleware.""" from __future__ import annotations import os from datetime import datetime, timezone from typing import Any, Dict, Iterable, List, Mapping, MutableMapping from uuid import UUID, uuid4 import psycopg from psycopg import sql from psycopg.rows import dict_row from dotenv import load_dotenv from flask import Flask, abort, jsonify, request, g from flask_cors import CORS load_dotenv() API_TOKEN = os.getenv("API_TOKEN") REQUIRED_DB_SETTINGS = { "DB_NAME": os.getenv("DB_NAME", ""), "DB_HOST": os.getenv("DB_HOST", ""), "DB_PORT": os.getenv("DB_PORT", ""), "DB_USERNAME": os.getenv("DB_USERNAME", ""), "DB_PASSWORD": os.getenv("DB_PASSWORD", ""), } missing_db_settings = [ name for name, value in REQUIRED_DB_SETTINGS.items() if not value ] if missing_db_settings: missing = ", ".join(missing_db_settings) raise RuntimeError( f"Database configuration missing for: {missing}. Did you configure the .env file?" ) DB_NAME = REQUIRED_DB_SETTINGS["DB_NAME"] DB_HOST = REQUIRED_DB_SETTINGS["DB_HOST"] DB_USERNAME = REQUIRED_DB_SETTINGS["DB_USERNAME"] DB_PASSWORD = REQUIRED_DB_SETTINGS["DB_PASSWORD"] try: DB_PORT = int(REQUIRED_DB_SETTINGS["DB_PORT"]) except ValueError as exc: raise RuntimeError("DB_PORT must be an integer") from exc USER_TABLE = os.getenv("DB_TABLE_USERS", "auth_user") INVESTMENT_PROFILE_TABLE = os.getenv( "DB_TABLE_INVESTMENT_PROFILES", "users_investmentprofile" ) SCRAPER_TABLE = os.getenv("DB_TABLE_SCRAPERS", "scraper") if not API_TOKEN: raise RuntimeError( "API_TOKEN missing from environment. Did you configure the .env file?" ) app = Flask(__name__) CORS(app, resources={r"/*": {"origins": "*"}}, supports_credentials=True) def get_db_connection() -> psycopg.Connection: connection = psycopg.connect( dbname=DB_NAME, user=DB_USERNAME, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT, row_factory=dict_row, ) connection.autocommit = True return connection def get_db() -> psycopg.Connection: if "db_connection" not in g: try: g.db_connection = get_db_connection() except psycopg.OperationalError: abort(503, description="Database connection failed") return g.db_connection @app.teardown_appcontext def close_db_connection(_: BaseException | None) -> None: db_connection = g.pop("db_connection", None) if db_connection is not None: db_connection.close() def _get_json_body() -> MutableMapping[str, Any]: payload = request.get_json(silent=True) if not isinstance(payload, MutableMapping): abort(400, description="Request body must be a JSON object") return payload def _parse_bool(value: Any, field_name: str) -> bool: if isinstance(value, bool): return value if isinstance(value, str): lowered = value.strip().lower() if lowered in {"true", "1", "yes", "y"}: return True if lowered in {"false", "0", "no", "n"}: return False if isinstance(value, (int, float)): if value in {0, 1}: return bool(value) abort(400, description=f"Field '{field_name}' must be a boolean value") def _parse_datetime(value: Any, field_name: str) -> datetime | None: if value is None: return None if isinstance(value, datetime): dt = value elif isinstance(value, str): try: normalized = value.replace("Z", "+00:00") dt = datetime.fromisoformat(normalized) except ValueError: abort( 400, description=f"Field '{field_name}' must be a valid ISO 8601 datetime", ) else: abort( 400, description=f"Field '{field_name}' must be a valid ISO 8601 datetime" ) if dt.tzinfo is None: dt = dt.replace(tzinfo=timezone.utc) else: dt = dt.astimezone(timezone.utc) return dt def _parse_string(value: Any, field_name: str, *, allow_empty: bool = False) -> str: if not isinstance(value, str): abort(400, description=f"Field '{field_name}' must be a string") stripped = value.strip() if not allow_empty and not stripped: abort(400, description=f"Field '{field_name}' cannot be empty") return stripped if not allow_empty else value def _isoformat(dt: datetime | None) -> str | None: if dt is None: return None if dt.tzinfo is None: dt = dt.replace(tzinfo=timezone.utc) return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z") def _ensure_json_compatible(value: Any, field_name: str) -> Any: if isinstance(value, (dict, list, str, int, float, bool)) or value is None: return value abort(400, description=f"Field '{field_name}' must be valid JSON data") def _parse_int(value: Any, field_name: str) -> int: if isinstance(value, int): return value if isinstance(value, str): try: return int(value, 10) except ValueError: abort(400, description=f"Field '{field_name}' must be an integer") abort(400, description=f"Field '{field_name}' must be an integer") def _parse_uuid(value: Any, field_name: str) -> UUID: if isinstance(value, UUID): return value if isinstance(value, str): try: return UUID(value) except ValueError: abort(400, description=f"Field '{field_name}' must be a valid UUID") abort(400, description=f"Field '{field_name}' must be a valid UUID") def _require_bearer_token(header_value: str | None) -> str: if not header_value: abort(401, description="Missing bearer token") parts = header_value.strip().split() if len(parts) != 2 or parts[0].lower() != "bearer": abort(401, description="Authorization header must be 'Bearer '") token = parts[1].strip() if not token: abort(401, description="Authorization header must include a token") return token def _serialize_row( row: Mapping[str, Any], *, datetime_fields: Iterable[str] | None = None ) -> Dict[str, Any]: result: Dict[str, Any] = dict(row) for field in datetime_fields or (): result[field] = _isoformat(result.get(field)) for field, value in list(result.items()): if isinstance(value, UUID): result[field] = str(value) return result def _fetch_one( query: sql.Composed, params: Mapping[str, Any] ) -> Mapping[str, Any] | None: conn = get_db() with conn.cursor() as cur: cur.execute(query, params) return cur.fetchone() def _fetch_all( query: sql.Composed, params: Mapping[str, Any] | None = None ) -> List[Mapping[str, Any]]: conn = get_db() with conn.cursor() as cur: cur.execute(query, params or {}) return cur.fetchall() def _execute(query: sql.Composed, params: Mapping[str, Any]) -> None: conn = get_db() with conn.cursor() as cur: cur.execute(query, params) def _columns_sql(columns: Iterable[str]) -> sql.Composed: return sql.SQL(", ").join(sql.Identifier(column) for column in columns) def _placeholders(columns: Iterable[str]) -> sql.Composed: return sql.SQL(", ").join(sql.Placeholder(column) for column in columns) def _insert_row( table: str, data: Mapping[str, Any], returning: Iterable[str] ) -> Mapping[str, Any]: if not data: raise ValueError("Cannot insert without data") query = sql.SQL( "INSERT INTO {table} ({columns}) VALUES ({values}) RETURNING {returning}" ).format( table=sql.Identifier(table), columns=_columns_sql(data.keys()), values=_placeholders(data.keys()), returning=_columns_sql(returning), ) row = _fetch_one(query, data) if row is None: raise RuntimeError("Insert statement did not return a row") return row def _update_row( table: str, identifier_column: str, identifier_value: Any, data: Mapping[str, Any], returning: Iterable[str], ) -> Mapping[str, Any] | None: if not data: raise ValueError("Cannot update without data") assignments = sql.SQL(", ").join( sql.SQL("{column} = {placeholder}").format( column=sql.Identifier(column), placeholder=sql.Placeholder(column), ) for column in data.keys() ) query = sql.SQL( "UPDATE {table} SET {assignments} WHERE {identifier_column} = {identifier} " "RETURNING {returning}" ).format( table=sql.Identifier(table), assignments=assignments, identifier_column=sql.Identifier(identifier_column), identifier=sql.Placeholder("identifier"), returning=_columns_sql(returning), ) params: Dict[str, Any] = dict(data) params["identifier"] = identifier_value return _fetch_one(query, params) def _delete_row(table: str, identifier_column: str, identifier_value: Any) -> bool: query = sql.SQL( "DELETE FROM {table} WHERE {identifier_column} = {identifier}" ).format( table=sql.Identifier(table), identifier_column=sql.Identifier(identifier_column), identifier=sql.Placeholder("identifier"), ) conn = get_db() with conn.cursor() as cur: cur.execute(query, {"identifier": identifier_value}) return cur.rowcount > 0 def _abort_for_integrity_error(exc: psycopg.IntegrityError) -> None: detail = getattr(getattr(exc, "diag", None), "message_detail", None) abort(409, description=detail or "Database constraint violation") USER_RESPONSE_FIELDS = ( "id", "username", "first_name", "last_name", "email", "is_superuser", "is_staff", "is_active", "date_joined", "last_login", ) USER_DATETIME_FIELDS = ("date_joined", "last_login") USER_BOOL_FIELDS = ("is_superuser", "is_staff", "is_active") PROFILE_RESPONSE_FIELDS = ( "profile_id", "name", "description", "criteria", "created_at", "is_active", ) PROFILE_DATETIME_FIELDS = ("created_at",) SCRAPER_RESPONSE_FIELDS = ( "id", "params", "last_seen_days", "first_seen_days", "frequency", "task_name", "enabled", "property_types", "page_size", "max_pages", "enrich_llm", "only_match", ) SCRAPER_INT_FIELDS = ( "last_seen_days", "first_seen_days", "page_size", "max_pages", "enabled", "enrich_llm", "only_match", ) @app.before_request def enforce_bearer_token() -> None: if request.method == "OPTIONS": return # Allow Flask internals without auth. if request.endpoint == "static": return provided_token = _require_bearer_token(request.headers.get("Authorization")) if provided_token != API_TOKEN: abort(401, description="Invalid bearer token") @app.get("/profiles") def get_profiles(): rows = _fetch_all( sql.SQL("SELECT {columns} FROM {table} ORDER BY created_at DESC").format( columns=_columns_sql(PROFILE_RESPONSE_FIELDS), table=sql.Identifier(INVESTMENT_PROFILE_TABLE), ) ) payload = [ _serialize_row(row, datetime_fields=PROFILE_DATETIME_FIELDS) for row in rows ] return jsonify(payload) @app.get("/profiles/") def get_profile(profile_id: str): profile_uuid = _parse_uuid(profile_id, "profile_id") row = _fetch_one( sql.SQL("SELECT {columns} FROM {table} WHERE profile_id = {identifier}").format( columns=_columns_sql(PROFILE_RESPONSE_FIELDS), table=sql.Identifier(INVESTMENT_PROFILE_TABLE), identifier=sql.Placeholder("profile_id"), ), {"profile_id": profile_uuid}, ) if row is None: abort(404, description="Profile not found") return jsonify(_serialize_row(row, datetime_fields=PROFILE_DATETIME_FIELDS)) @app.post("/profiles") def create_profile(): payload = _get_json_body() profile_identifier = payload.get("profile_id") profile_uuid = ( _parse_uuid(profile_identifier, "profile_id") if profile_identifier else uuid4() ) name = _parse_string(payload.get("name"), "name") description_value = payload.get("description") description = ( None if description_value is None else _parse_string(description_value, "description", allow_empty=True) ) criteria_raw = payload.get("criteria") if criteria_raw is None: abort(400, description="Field 'criteria' is required") criteria = _ensure_json_compatible(criteria_raw, "criteria") created_at_value = payload.get("created_at") created_at = ( datetime.now(timezone.utc) if created_at_value is None else _parse_datetime(created_at_value, "created_at") ) is_active = _parse_bool(payload.get("is_active", True), "is_active") data = { "profile_id": profile_uuid, "name": name, "description": description, "criteria": criteria, "created_at": created_at, "is_active": is_active, } try: row = _insert_row(INVESTMENT_PROFILE_TABLE, data, PROFILE_RESPONSE_FIELDS) except psycopg.IntegrityError as exc: _abort_for_integrity_error(exc) return ( jsonify(_serialize_row(row, datetime_fields=PROFILE_DATETIME_FIELDS)), 201, ) @app.put("/profiles/") def update_profile(profile_id: str): profile_uuid = _parse_uuid(profile_id, "profile_id") payload = _get_json_body() updates: Dict[str, Any] = {} if "name" in payload: updates["name"] = _parse_string(payload["name"], "name") if "description" in payload: description_value = payload["description"] updates["description"] = ( None if description_value is None else _parse_string(description_value, "description", allow_empty=True) ) if "criteria" in payload: updates["criteria"] = _ensure_json_compatible(payload["criteria"], "criteria") if "created_at" in payload: updates["created_at"] = _parse_datetime(payload["created_at"], "created_at") if "is_active" in payload: updates["is_active"] = _parse_bool(payload["is_active"], "is_active") if not updates: abort(400, description="No updatable fields provided") try: row = _update_row( INVESTMENT_PROFILE_TABLE, "profile_id", profile_uuid, updates, PROFILE_RESPONSE_FIELDS, ) except psycopg.IntegrityError as exc: _abort_for_integrity_error(exc) if row is None: abort(404, description="Profile not found") return jsonify(_serialize_row(row, datetime_fields=PROFILE_DATETIME_FIELDS)) @app.delete("/profiles/") def delete_profile(profile_id: str): profile_uuid = _parse_uuid(profile_id, "profile_id") deleted = _delete_row(INVESTMENT_PROFILE_TABLE, "profile_id", profile_uuid) if not deleted: abort(404, description="Profile not found") return "", 204 @app.get("/users") def get_users(): rows = _fetch_all( sql.SQL("SELECT {columns} FROM {table} ORDER BY id").format( columns=_columns_sql(USER_RESPONSE_FIELDS), table=sql.Identifier(USER_TABLE), ) ) payload = [ _serialize_row(row, datetime_fields=USER_DATETIME_FIELDS) for row in rows ] return jsonify(payload) @app.get("/users/") def get_user(user_id: int): row = _fetch_one( sql.SQL("SELECT {columns} FROM {table} WHERE id = {identifier}").format( columns=_columns_sql(USER_RESPONSE_FIELDS), table=sql.Identifier(USER_TABLE), identifier=sql.Placeholder("user_id"), ), {"user_id": user_id}, ) if row is None: abort(404, description="User not found") return jsonify(_serialize_row(row, datetime_fields=USER_DATETIME_FIELDS)) @app.post("/users") def create_user(): payload = _get_json_body() user_data: Dict[str, Any] = {} user_data["password"] = _parse_string(payload.get("password"), "password") user_data["username"] = _parse_string(payload.get("username"), "username") user_data["first_name"] = _parse_string(payload.get("first_name"), "first_name") user_data["last_name"] = _parse_string(payload.get("last_name"), "last_name") user_data["email"] = _parse_string(payload.get("email"), "email") user_data["is_superuser"] = _parse_bool( payload.get("is_superuser", False), "is_superuser" ) user_data["is_staff"] = _parse_bool(payload.get("is_staff", False), "is_staff") user_data["is_active"] = _parse_bool(payload.get("is_active", True), "is_active") user_data["date_joined"] = _parse_datetime( payload.get("date_joined"), "date_joined" ) if user_data["date_joined"] is None: user_data["date_joined"] = datetime.now(timezone.utc) user_data["last_login"] = _parse_datetime(payload.get("last_login"), "last_login") try: row = _insert_row(USER_TABLE, user_data, USER_RESPONSE_FIELDS) except psycopg.IntegrityError as exc: _abort_for_integrity_error(exc) return ( jsonify(_serialize_row(row, datetime_fields=USER_DATETIME_FIELDS)), 201, ) @app.put("/users/") def update_user(user_id: int): payload = _get_json_body() updates: Dict[str, Any] = {} if "password" in payload: updates["password"] = _parse_string(payload["password"], "password") if "username" in payload: updates["username"] = _parse_string(payload["username"], "username") if "first_name" in payload: updates["first_name"] = _parse_string(payload["first_name"], "first_name") if "last_name" in payload: updates["last_name"] = _parse_string(payload["last_name"], "last_name") if "email" in payload: updates["email"] = _parse_string(payload["email"], "email") for field in USER_BOOL_FIELDS: if field in payload: updates[field] = _parse_bool(payload[field], field) if "date_joined" in payload: updates["date_joined"] = _parse_datetime(payload["date_joined"], "date_joined") if "last_login" in payload: updates["last_login"] = _parse_datetime(payload["last_login"], "last_login") if not updates: abort(400, description="No updatable fields provided") try: row = _update_row(USER_TABLE, "id", user_id, updates, USER_RESPONSE_FIELDS) except psycopg.IntegrityError as exc: _abort_for_integrity_error(exc) if row is None: abort(404, description="User not found") return jsonify(_serialize_row(row, datetime_fields=USER_DATETIME_FIELDS)) @app.delete("/users/") def delete_user(user_id: int): deleted = _delete_row(USER_TABLE, "id", user_id) if not deleted: abort(404, description="User not found") return "", 204 @app.get("/scrapers") def get_scrapers(): rows = _fetch_all( sql.SQL("SELECT {columns} FROM {table} ORDER BY id").format( columns=_columns_sql(SCRAPER_RESPONSE_FIELDS), table=sql.Identifier(SCRAPER_TABLE), ) ) return jsonify([dict(row) for row in rows]) @app.get("/scrapers/") def get_scraper(scraper_id: str): row = _fetch_one( sql.SQL("SELECT {columns} FROM {table} WHERE id = {identifier}").format( columns=_columns_sql(SCRAPER_RESPONSE_FIELDS), table=sql.Identifier(SCRAPER_TABLE), identifier=sql.Placeholder("scraper_id"), ), {"scraper_id": _parse_string(scraper_id, "id")}, ) if row is None: abort(404, description="Scraper not found") return jsonify(dict(row)) @app.post("/scrapers") def create_scraper(): payload = _get_json_body() scraper_id = _parse_string(payload.get("id"), "id") data: Dict[str, Any] = {"id": scraper_id} for field in ("params", "frequency", "task_name", "property_types"): if field in payload: value = payload[field] data[field] = ( None if value is None else _parse_string(value, field, allow_empty=True) ) for field in SCRAPER_INT_FIELDS: if field in payload: value = payload[field] data[field] = None if value is None else _parse_int(value, field) try: row = _insert_row(SCRAPER_TABLE, data, SCRAPER_RESPONSE_FIELDS) except psycopg.IntegrityError as exc: _abort_for_integrity_error(exc) return jsonify(dict(row)), 201 @app.put("/scrapers/") def update_scraper(scraper_id: str): payload = _get_json_body() updates: Dict[str, Any] = {} for field in ("params", "frequency", "task_name", "property_types"): if field in payload: value = payload[field] updates[field] = ( None if value is None else _parse_string(value, field, allow_empty=True) ) for field in SCRAPER_INT_FIELDS: if field in payload: value = payload[field] updates[field] = None if value is None else _parse_int(value, field) if not updates: abort(400, description="No updatable fields provided") try: row = _update_row( SCRAPER_TABLE, "id", _parse_string(scraper_id, "id"), updates, SCRAPER_RESPONSE_FIELDS, ) except psycopg.IntegrityError as exc: _abort_for_integrity_error(exc) if row is None: abort(404, description="Scraper not found") return jsonify(dict(row)) @app.delete("/scrapers/") def delete_scraper(scraper_id: str): deleted = _delete_row(SCRAPER_TABLE, "id", _parse_string(scraper_id, "id")) if not deleted: abort(404, description="Scraper not found") return "", 204 if __name__ == "__main__": app.run(host="0.0.0.0", port=int(os.getenv("PORT", "8000")), debug=False)