# Hugging Face Space (status when captured: Sleeping)
import logging

import joblib
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Application instance plus a permissive CORS policy for browser clients.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    # Credentialed requests (cookies / Authorization headers) must not be
    # combined with wildcard origins, methods, or headers; this API does not
    # use credentials, so they are disabled rather than listing origins.
    allow_credentials=False,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
# Serialized artifacts, given as paths relative to the working directory.
_MODEL_PATH = "soil_npk_joblib_model.joblib"
_LABEL_ENCODER_PATH = "label_encoder.joblib"

# Deserialize the model and the label encoder at import time; the request
# handler reads these module-level objects.
model = joblib.load(_MODEL_PATH)
le = joblib.load(_LABEL_ENCODER_PATH)
class InputData(BaseModel):
    """Request body for one prediction: a crop plus field and soil readings.

    Units and valid ranges are not specified in this file; they are
    presumably those of the data the model was trained on (TODO confirm).
    """

    # Crop identifier; the handler checks it against the label encoder's
    # known classes before encoding it.
    crop_name: str

    # Production goal and field extent.
    target_yield: float
    field_size: float

    # Soil readings.
    ph: float
    organic_carbon: float
    nitrogen: float
    phosphorus: float
    potassium: float
    soil_moisture: float
# Feature columns in the order the DataFrame is built and passed to the model.
_FEATURE_COLUMNS = (
    "crop_name",
    "target_yield",
    "field_size",
    "ph",
    "organic_carbon",
    "nitrogen",
    "phosphorus",
    "potassium",
    "soil_moisture",
)

# Response keys, positionally matched to the columns of the model output.
_PREDICTION_FIELDS = (
    "nitrogen_need",
    "phosphorus_need",
    "potassium_need",
    "organic_matter_need",
    "lime_need",
)


# NOTE(review): no route decorator is present in the visible source, so this
# handler was never registered with the app. POST "/predict" is assumed from
# the function name and its request-body model -- confirm against the client.
@app.post("/predict")
async def predict(data: InputData):
    """Predict nutrient needs for a single crop/soil sample.

    Args:
        data: Validated request body carrying the crop name and the numeric
            field and soil readings.

    Returns:
        A dict mapping each name in ``_PREDICTION_FIELDS`` to a float taken
        from the first row of the model output.

    Raises:
        HTTPException: 400 when ``crop_name`` is not one of the label
            encoder's classes; 500 for any failure while building the
            features or running the model.
    """
    # An unknown crop is a client mistake, so report it as 400 instead of
    # letting the catch-all below turn it into a server error.
    if data.crop_name not in le.classes_:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid crop_name: {data.crop_name}",
        )

    try:
        # Single-row frame; each value is wrapped in a list to form one row.
        features = pd.DataFrame(
            {column: [getattr(data, column)] for column in _FEATURE_COLUMNS}
        )

        # Use the encoder to transform the crop_name
        features["crop_name"] = le.transform(features["crop_name"])

        # Guard against a model/feature-list mismatch before predicting.
        expected_shape = model.n_features_in_
        if features.shape[1] != expected_shape:
            raise ValueError(
                f"Input shape mismatch. Expected {expected_shape} features, got {features.shape[1]}"
            )

        first_row = model.predict(features)[0]
        # Index explicitly so a model with too few outputs still fails loudly.
        return {
            name: float(first_row[position])
            for position, name in enumerate(_PREDICTION_FIELDS)
        }
    except Exception as e:
        logging.error("Error in predict function: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is present in the visible source, so this
# handler was never registered with the app. GET "/" is assumed from the
# function name -- confirm.
@app.get("/")
async def root():
    """Return a static message identifying the service."""
    return {"message": "NPK Needs Prediction API"}