# models.py
  1. """
  2. Models Router - Proxy to Model Service
  3. """
  4. from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
  5. from fastapi.responses import JSONResponse
  6. from typing import List, Optional
  7. import httpx
  8. import asyncio
  9. from shared.config import settings
  10. from shared.auth import get_current_user
  11. from shared.models.user import User
  12. from shared.models.model import ModelRequest, ModelResponse, ModelList
# Shared APIRouter instance; the endpoint decorators below attach to it and
# the main application mounts it under its prefix.
router = APIRouter()
  14. # Model service URL
  15. MODEL_SERVICE_URL = f"http://model-service:8001"
  16. @router.get("/models", response_model=List[ModelList])
  17. async def list_models(
  18. current_user: User = Depends(get_current_user),
  19. skip: int = 0,
  20. limit: int = 100
  21. ):
  22. """List available models"""
  23. try:
  24. async with httpx.AsyncClient() as client:
  25. response = await client.get(
  26. f"{MODEL_SERVICE_URL}/models",
  27. params={"skip": skip, "limit": limit},
  28. headers={"Authorization": f"Bearer {current_user.token}"}
  29. )
  30. response.raise_for_status()
  31. return response.json()
  32. except httpx.HTTPError as e:
  33. raise HTTPException(status_code=500, detail=f"Model service error: {str(e)}")
  34. @router.post("/models/{model_id}/predict", response_model=ModelResponse)
  35. async def predict(
  36. model_id: str,
  37. request: ModelRequest,
  38. background_tasks: BackgroundTasks,
  39. current_user: User = Depends(get_current_user)
  40. ):
  41. """Make prediction with specified model"""
  42. try:
  43. async with httpx.AsyncClient() as client:
  44. response = await client.post(
  45. f"{MODEL_SERVICE_URL}/models/{model_id}/predict",
  46. json=request.dict(),
  47. headers={"Authorization": f"Bearer {current_user.token}"}
  48. )
  49. response.raise_for_status()
  50. # Log prediction request
  51. background_tasks.add_task(
  52. log_prediction_request,
  53. model_id,
  54. current_user.id,
  55. request
  56. )
  57. return response.json()
  58. except httpx.HTTPError as e:
  59. raise HTTPException(status_code=500, detail=f"Model service error: {str(e)}")
  60. @router.get("/models/{model_id}", response_model=ModelList)
  61. async def get_model(
  62. model_id: str,
  63. current_user: User = Depends(get_current_user)
  64. ):
  65. """Get model details"""
  66. try:
  67. async with httpx.AsyncClient() as client:
  68. response = await client.get(
  69. f"{MODEL_SERVICE_URL}/models/{model_id}",
  70. headers={"Authorization": f"Bearer {current_user.token}"}
  71. )
  72. response.raise_for_status()
  73. return response.json()
  74. except httpx.HTTPError as e:
  75. raise HTTPException(status_code=500, detail=f"Model service error: {str(e)}")
  76. @router.post("/models/{model_id}/train")
  77. async def train_model(
  78. model_id: str,
  79. training_data: dict,
  80. background_tasks: BackgroundTasks,
  81. current_user: User = Depends(get_current_user)
  82. ):
  83. """Start model training"""
  84. try:
  85. async with httpx.AsyncClient() as client:
  86. response = await client.post(
  87. f"{MODEL_SERVICE_URL}/models/{model_id}/train",
  88. json=training_data,
  89. headers={"Authorization": f"Bearer {current_user.token}"}
  90. )
  91. response.raise_for_status()
  92. # Log training request
  93. background_tasks.add_task(
  94. log_training_request,
  95. model_id,
  96. current_user.id,
  97. training_data
  98. )
  99. return response.json()
  100. except httpx.HTTPError as e:
  101. raise HTTPException(status_code=500, detail=f"Model service error: {str(e)}")
  102. @router.get("/models/{model_id}/status")
  103. async def get_model_status(
  104. model_id: str,
  105. current_user: User = Depends(get_current_user)
  106. ):
  107. """Get model training status"""
  108. try:
  109. async with httpx.AsyncClient() as client:
  110. response = await client.get(
  111. f"{MODEL_SERVICE_URL}/models/{model_id}/status",
  112. headers={"Authorization": f"Bearer {current_user.token}"}
  113. )
  114. response.raise_for_status()
  115. return response.json()
  116. except httpx.HTTPError as e:
  117. raise HTTPException(status_code=500, detail=f"Model service error: {str(e)}")
  118. async def log_prediction_request(model_id: str, user_id: str, request: ModelRequest):
  119. """Log prediction request for analytics"""
  120. # Implementation would log to database or analytics service
  121. pass
  122. async def log_training_request(model_id: str, user_id: str, training_data: dict):
  123. """Log training request for analytics"""
  124. # Implementation would log to database or analytics service
  125. pass