用Tensorflow和FastAPI构建图像分类API,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。

创新互联公司长期为超过千家客户提供的网站建设服务,团队从业经验10年,关注不同地域、不同群体,并针对不同对象提供差异化的产品和服务;打造开放共赢平台,与合作伙伴共同营造健康的互联网生态环境。为新宁企业提供专业的网站设计、做网站,新宁网站改版等技术服务。拥有十载丰富建站经验和众多成功案例,为您定制开发。
让我们从一个简单的helloworld示例开始
首先,我们导入FastAPI类并创建一个对象应用程序。这个类有一些有用的参数,比如我们可以传递swaggerui的标题和描述。
from fastapi import FastAPI app = FastAPI(title='Hello world')
我们定义一个函数并用@app.get. 这意味着我们的API/index支持GET方法。这里定义的函数是异步的,FastAPI通过为普通的def函数创建线程池来自动处理异步和不使用异步方法,并且它为异步函数使用异步事件循环。
@app.get('/index')
async def hello_world():
    return "hello world"我们将创建一个API来对图像进行分类,我们将其命名为predict/image。我们将使用Tensorflow来创建图像分类模型。
Tensorflow图像分类教程:https://aniketmaurya.ml/blog/tensorflow/deep%20learning/2019/05/12/image-classification-with-tf2.html
我们创建了一个函数load_model,它将返回一个带有预训练权重的MobileNet CNN模型,即它已经被训练为对1000个不同类别的图像进行分类。
import tensorflow as tf
def load_model():
    model = tf.keras.applications.MobileNetV2(weights="imagenet")
    print("Model loaded")
    return model
    
model = load_model()我们定义了一个predict函数,它将接受图像并返回预测。我们将图像大小调整为224x224,并将像素值规格化为[-1,1]。
from tensorflow.keras.applications.imagenet_utils import decode_predictions
decode_predictions用于解码预测对象的类名。这里我们将返回前2个可能的类。
def predict(image: Image.Image):
    image = np.asarray(image.resize((224, 224)))[..., :3]
    image = np.expand_dims(image, 0)
    image = image / 127.5 - 1.0
    
    result = decode_predictions(model.predict(image), 2)[0]
    
    response = []
    
    for i, res in enumerate(result):
        resp = {}
        resp["class"] = res[1]
        resp["confidence"] = f"{res[2]*100:0.2f} %"
        
        response.append(resp)
        
    return response现在我们将创建一个支持文件上传的API/predict/image。我们将过滤文件扩展名以仅支持jpg、jpeg和png格式的图像。
我们将使用Pillow加载上传的图像。
def read_imagefile(file) -> Image.Image:
    image = Image.open(BytesIO(file))
    return image
    
@app.post("/predict/image")
async def predict_api(file: UploadFile = File(...)):
    extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
    if not extension:
        return "Image must be jpg or png format!"
    image = read_imagefile(await file.read())
    prediction = predict(image)
    
    return predictionimport uvicorn
from fastapi import FastAPI, File, UploadFile
from application.components import predict, read_imagefile
app = FastAPI()
@app.post("/predict/image")
async def predict_api(file: UploadFile = File(...)):
    extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
    if not extension:
        return "Image must be jpg or png format!"
    image = read_imagefile(await file.read())
    prediction = predict(image)
    
    return prediction
    
@app.post("/api/covid-symptom-check")
def check_risk(symptom: Symptom):
    return symptom_check.get_risk_level(symptom)
    
if __name__ == "__main__":
    uvicorn.run(app, debug=True)看完上述内容是否对您有帮助呢?如果还想对相关知识有进一步的了解或阅读更多相关文章,请关注创新互联行业资讯频道,感谢您对创新互联的支持。