import functools
import json
import os

import numpy as np
from flask import Flask, jsonify, request
from gevent.pywsgi import WSGIServer
from sentence_transformers import SentenceTransformer
# Flask application object; the /embedding route below is registered on it and
# it is served by gevent's WSGIServer in the __main__ block.
app = Flask(__name__)
@app.route('/embedding', methods=['GET','POST'])
def health_check():
    """Return embeddings for the documents in the request's JSON body.

    Despite its name, this view serves ``/embedding``; the function name is
    kept because Flask uses it as the endpoint name (e.g. for ``url_for``).

    Expected JSON body::

        {"doc_list": "<text>" | ["<text>", ...], "model_name": "<model dir>"}

    Returns:
        200 with the JSON produced by ``request_parse`` on success, or
        400 with ``{"error": ...}`` when the body is missing / not a JSON
        object, or when a field is missing or has the wrong type.
    """
    # Parse the body once. silent=True makes a missing or non-JSON body yield
    # None instead of raising, so the client gets a clear 400 instead of an
    # opaque error page (e.g. a bare GET with no body).
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    doc_list = payload.get("doc_list")
    model_name = payload.get("model_name")

    valid_docs = isinstance(doc_list, str) or (
        isinstance(doc_list, list) and all(isinstance(doc, str) for doc in doc_list)
    )
    if not valid_docs:
        return jsonify({"error": "'doc_list' must be a string or a list of strings"}), 400
    if not isinstance(model_name, str) or not model_name:
        return jsonify({"error": "'model_name' must be a non-empty string"}), 400

    data = request_parse(doc_list, model_name)
    # request_parse returns an already-serialized JSON string; without an
    # explicit Content-Type, Flask would label a str body as text/html.
    return data, 200, {"Content-Type": "application/json"}
class json_serialize(json.JSONEncoder):
    """JSON encoder that turns NumPy scalars and arrays into native Python values."""

    def default(self, obj):
        """Convert *obj* to a JSON-serializable value, or defer to the base encoder.

        NumPy integers become ``int``, NumPy floating scalars become ``float``
        and arrays become (nested) lists. Anything else goes to
        ``json.JSONEncoder.default``, which raises ``TypeError``.
        """
        converters = (
            (np.integer, int),
            (np.floating, float),
            (np.ndarray, lambda arr: arr.tolist()),
        )
        # Checked in order; the first matching NumPy type wins.
        for numpy_type, convert in converters:
            if isinstance(obj, numpy_type):
                return convert(obj)
        return super().default(obj)
@functools.lru_cache(maxsize=4)
def _load_model(model_path):
    """Load a SentenceTransformer from *model_path* and keep it in memory.

    Loading a model from disk is far more expensive than encoding a few
    documents, so the most recently used models stay resident instead of
    being reloaded on every request. ``maxsize`` bounds how many models
    are held in memory at once.
    """
    return SentenceTransformer(model_path)


def request_parse(doc_list, model_name, embedding_path='/root/.cache/huggingface/'):
    """Encode documents with a locally stored SentenceTransformer and return JSON.

    Args:
        doc_list: text(s) passed straight to ``SentenceTransformer.encode``.
        model_name: model directory relative to ``embedding_path``. It may
            contain sub-directories but must not point outside
            ``embedding_path``.
        embedding_path: directory that holds the local model folders.

    Returns:
        A JSON string of the form ``{"data": <embeddings>}``. NumPy values
        are converted by ``json_serialize``.

    Raises:
        TypeError: if ``model_name`` is not a string.
        ValueError: if ``model_name`` is empty or resolves outside
            ``embedding_path`` (e.g. ``"../x"`` or an absolute path).
    """
    if not isinstance(model_name, str):
        raise TypeError("model_name must be a string")

    # model_name comes from the HTTP request, so normalise the joined path
    # and confine it to the model root. This blocks "../" traversal and
    # absolute paths. normpath is purely lexical, so symlinked model
    # folders inside the root keep working.
    root = os.path.normpath(embedding_path)
    model_path = os.path.normpath(os.path.join(root, model_name))
    if model_path == root or os.path.commonpath([root, model_path]) != root:
        raise ValueError(f"invalid model_name: {model_name!r}")

    model = _load_model(model_path)
    embeddings = model.encode(doc_list)
    return json.dumps({'data': embeddings}, cls=json_serialize)
if __name__ == "__main__":
    # gevent's WSGIServer is used instead of Flask's built-in development
    # server (app.run). It listens on the loopback interface only, port 5000.
    listen_on = ('127.0.0.1', 5000)
    WSGIServer(listen_on, app).serve_forever()