-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
58 lines (49 loc) · 2.06 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from flask import Flask, request, jsonify
from flask_cors import CORS
import requests
app = Flask(__name__)
@app.after_request
def apply_cors(response):
    """Set permissive CORS headers on every outgoing response.

    Registered as an after-request hook, so it receives each response
    object, overwrites the three Access-Control-* headers below and
    returns the same object.
    """
    cors_headers = (
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type, Authorization"),
    )
    for header_name, header_value in cors_headers:
        response.headers[header_name] = header_value
    return response
# Deployment environment name, taken from FLASK_ENV; falls back to
# 'production' when the variable is not set.
ENV = os.environ.get('FLASK_ENV', 'production')

# Must be the same as the model name given on OpenShift AI.
deployed_model_name = "fraud"

# Model server endpoint.
rest_url = "http://modelmesh-serving:8008"

# Inference URL for the deployed model. It is derived from
# deployed_model_name, so change that value if the model is renamed.
infer_url = "/".join((rest_url, "v2", "models", deployed_model_name, "infer"))
# Load scaler.pkl, which (per the original note) holds the pre-trained
# scikit-learn scaler used to standardize or normalize input data. Applying
# the same scaler at inference keeps the features fed to the model scaled
# consistently with how they were scaled during training.
import pickle

with open('artifact/scaler.pkl', mode='rb') as handle:
    scaler = pickle.loads(handle.read())
# Handle the request to the model server
def rest_request(data, timeout=30):
    """Send one feature vector to the model server and return its output.

    Args:
        data: Flat list of feature values. It is sent as the ``data`` field
            of a single input tensor named ``dense_input``, declared as
            FP32 with shape ``[1, 5]``.
        timeout: Seconds to wait for the model server before giving up.

    Returns:
        The ``data`` field of the first entry in the response's ``outputs``.

    Raises:
        requests.exceptions.RequestException: If the server cannot be
            reached, does not answer within ``timeout``, or replies with an
            HTTP error status.
        KeyError, IndexError: If a successful response does not contain the
            expected ``outputs[0]['data']`` field.
    """
    payload = {
        "inputs": [
            {
                "name": "dense_input",
                "shape": [1, 5],
                "datatype": "FP32",
                "data": data,
            }
        ]
    }
    # requests waits indefinitely unless a timeout is given; bound the wait
    # so a stalled model server cannot hang this service's worker.
    response = requests.post(infer_url, json=payload, timeout=timeout)
    # Fail with a clear HTTP error rather than a KeyError further down when
    # the server answers with an error body instead of a prediction.
    response.raise_for_status()
    return response.json()['outputs'][0]['data']
# Endpoint
@app.route('/', methods=['POST'])
def check_fraud():
    """Classify the posted feature vector as 'fraud' or 'not fraud'.

    The request body is expected to be a JSON array of feature values. The
    values are passed through the module-level ``scaler`` and sent to the
    model server via ``rest_request``; the first value of the returned
    prediction is compared against a fixed threshold of 0.95.

    Returns:
        A JSON response ``{'message': 'fraud'}`` or
        ``{'message': 'not fraud'}``. If the body is missing, is not valid
        JSON, is not a non-empty array, or is rejected by the scaler, a
        JSON ``{'error': ...}`` response with HTTP status 400 is returned.
    """
    # silent=True yields None for a missing or unparsable body so that the
    # problem is reported through the JSON error response below.
    features = request.get_json(silent=True)
    if not isinstance(features, list) or not features:
        return jsonify({'error': 'request body must be a non-empty JSON array of feature values'}), 400

    try:
        scaled = scaler.transform([features]).tolist()[0]
    except (ValueError, TypeError):
        # NOTE(review): assumes the scaler signals a wrong feature count or
        # non-numeric values with ValueError/TypeError - confirm for the
        # scaler stored in artifact/scaler.pkl.
        return jsonify({'error': 'feature values were rejected by the scaler'}), 400

    # Place a request to the model server from this service.
    prediction = rest_request(scaled)
    threshold = 0.95
    message = 'fraud' if prediction[0] > threshold else 'not fraud'
    return jsonify({'message': message})
if __name__ == '__main__':
    # Debug mode turns on the interactive Werkzeug debugger, which lets
    # anyone who can reach the server execute code. Enable it only when
    # FLASK_ENV explicitly selects 'development'; ENV defaults to
    # 'production', so a plain `python app.py` runs without it.
    app.run(debug=(ENV == 'development'), port=5000)