Update finance api project to include rest api

This commit is contained in:
thecodebranch 2024-07-04 09:50:20 -06:00
parent 8a4f34855f
commit f630987b08
18 changed files with 559 additions and 82 deletions

2
api/config.py Normal file
View File

@ -0,0 +1,2 @@
class ApiConfig:
DATABASE_URI = "sqlite:///finance.db"

23
api/errors.py Normal file
View File

@ -0,0 +1,23 @@
class Error(Exception):
def __init__(self, value=""):
if not hasattr(self, "value"):
self.value = value
def __str__(self):
return repr(self.value)
class InvalidSma(Error):
message = "Invalid SMA period."
class ResourceNotFound(Error):
message = "Resource not found."
class DateFormatError(Error):
message = "Error with date parameters."
class Messages:
INTERNAL_SERVER_ERROR = "Server error. please try again later."

3
api/service/base.py Normal file
View File

@ -0,0 +1,3 @@
class BaseService:
def __init__(self, database=None):
self.database = database

45
api/service/historical.py Normal file
View File

@ -0,0 +1,45 @@
from datetime import datetime
from .base import BaseService
from api.errors import DateFormatError
from api.errors import ResourceNotFound
from model.ticker import Ticker
def validate_date(date_str):
try:
date_obj = datetime.strptime(date_str, "%Y-%m-%d")
return date_obj
except ValueError:
raise DateFormatError
class HistoricalService(BaseService):
def get_historical_ticker(self, symbol, params):
from_date = params.get("from_date", "2000-01-01")
to_date = params.get("to_date", "2001-01-01")
from_date = validate_date(from_date)
to_date = validate_date(to_date)
historical = (
self.database.session.query(Ticker)
.filter(Ticker.symbol == symbol)
.filter(Ticker.date >= from_date)
.filter(Ticker.date <= to_date)
.all()
)
if not historical:
raise ResourceNotFound
output = {}
data = []
for ticker in historical:
ticker = ticker.to_dict()
ticker.pop("symbol")
data.append(ticker)
output["ticker"] = symbol
output["data"] = data
return output

79
api/service/sma.py Normal file
View File

@ -0,0 +1,79 @@
from .base import BaseService
from model.sma import Sma
from api.errors import ResourceNotFound
from api.errors import InvalidSma
class SmaService(BaseService):
def get_sma(self, ticker, params):
# Validate SMA periods and get corresponding columns
sma_periods = params.get("period")
sma_columns = self._validate_and_get_columns(sma_periods)
from_date = params.get("from_date", "2000-01-01")
to_date = params.get("to_date", "2001-01-01")
dates = {"from_date": from_date, "to_date": to_date}
# Query the database for the necessary columns
data = self._query_sma_data(ticker, sma_columns, dates)
if not data:
raise ResourceNotFound
data = self._format_results(data, sma_columns)
result = {
"ticker": ticker,
"data": data
}
return result
def _validate_and_get_columns(self, sma_periods):
valid_sma_columns = {
"5": Sma.sma_5,
"10": Sma.sma_10,
"20": Sma.sma_20,
"50": Sma.sma_50,
"100": Sma.sma_100,
"200": Sma.sma_200,
}
sma_columns = []
for period in sma_periods.split(","):
if period not in valid_sma_columns.keys():
raise InvalidSma
sma_columns.append(valid_sma_columns[period])
return sma_columns
def _query_sma_data(self, ticker, sma_columns, dates):
# Start with the base query selecting the date and symbol
query = self.database.session.query(Sma.symbol, Sma.date, Sma.close)
# Add the SMA columns dynamically
query = query.add_columns(*sma_columns)
# Filter by the ticker symbol
query = query.filter(Sma.symbol == ticker)
query = query.filter(Sma.date >= dates.get("from_date"))
query = query.filter(Sma.date <= dates.get("to_date"))
# Execute the query and return the results
return query.all()
def _format_results(self, data, sma_columns):
# Format the results into a list of dictionaries
result = []
for row in data:
row_dict = {
"date": row.date.strftime("%Y-%m-%d"),
"close": row.close,
}
for i, column in enumerate(sma_columns):
# +2 because the first two columns are date and symbol
sma = row[i + 3]
row_dict[column.key] = sma
result.append(row_dict)
return result

11
api/service/ticker.py Normal file
View File

@ -0,0 +1,11 @@
from model.tickerSymbol import TickerSymbol
from .base import BaseService
class TickerService(BaseService):
def get_tickers(self):
try:
tickers = self.database.session.query(TickerSymbol).all()
return [ticker.symbol for ticker in tickers]
except Exception as e:
pass

6
api/views/base.py Normal file
View File

@ -0,0 +1,6 @@
from flask.views import MethodView
class BaseView(MethodView):
def __init__(self, service=None):
self.service=service

19
api/views/historical.py Normal file
View File

@ -0,0 +1,19 @@
from flask import jsonify
from flask import request
from .base import BaseView
from api.errors import DateFormatError
from api.errors import Messages
from api.errors import ResourceNotFound
class HistoricalView(BaseView):
def get(self, ticker):
try:
data = self.service.get_historical_ticker(ticker, params=request.args)
return jsonify(data), 200
except ResourceNotFound as e:
return jsonify({"error": e.message}), 404
except DateFormatError as e:
return jsonify({"error": e.message}), 400
except Exception as e:
return jsonify({"error": Messages.INTERNAL_SERVER_ERROR}), 500

19
api/views/sma.py Normal file
View File

@ -0,0 +1,19 @@
from flask import jsonify
from flask import request
from .base import BaseView
from api.errors import ResourceNotFound
from api.errors import InvalidSma
from api.errors import Messages
class SmaView(BaseView):
def get(self, ticker):
try:
data = self.service.get_sma(ticker, request.args)
return jsonify({"data": data}), 200
except InvalidSma as e:
return jsonify({"error": e.message}), 400
except ResourceNotFound as e:
return jsonify({"error": e.message}), 400
except Exception as e:
return jsonify({"error": Messages.INTERNAL_SERVER_ERROR}), 500

12
api/views/ticker.py Normal file
View File

@ -0,0 +1,12 @@
from flask import jsonify
from .base import BaseView
from api.errors import Messages
class TickerView(BaseView):
def get(self):
try:
data = self.service.get_tickers()
return jsonify({"data": data}), 200
except Exception as e:
return jsonify({"error": Messages.INTERNAL_SERVER_ERROR}), 500

View File

@ -29,14 +29,14 @@ class StockDataGenerator:
self.volume_adjustment_factor = volume_adjustment_factor
def generate_raw_data(self, start_date="2000-01-01", end_date="2024-01-01"):
def generate_raw_data(self, output_path=None, start_date="2000-01-01", end_date="2024-01-01"):
tickers = self._generate_fake_tickers()
dates = self._generate_dates()
ticker_count = 0
for ticker in tickers:
daily_ticker_data = self._generate_stock_data(ticker, dates)
self._write_to_csv(daily_ticker_data)
self._write_to_csv(daily_ticker_data, output_path)
ticker_count += 1
print(f"Generated data for: {ticker} and is {ticker_count} of {len(tickers)}")
@ -109,8 +109,12 @@ class StockDataGenerator:
return stock_data
def _write_to_csv(self, data):
with open(self.csv_output_file, "a") as file:
def _write_to_csv(self, data, output=None):
if output:
output_file = output
else:
output_file = self.csv_output_file
with open(output_file, "a") as file:
fieldnames = ["date", "symbol", "open", "high", "low", "close", "volume"]
writer = csv.DictWriter(file, fieldnames=fieldnames)
if file.tell() == 0:

View File

@ -13,7 +13,7 @@ def calculate_sma(data, period):
sma_values.append(None)
else:
sma = sum(close_prices[i + 1 - period:i + 1]) / period
sma_values.append(sma)
sma_values.append(round(sma, 3))
for i in range(len(data)):
data[i]['SMA'] = sma_values[i]

View File

@ -1,9 +1,12 @@
from datetime import datetime
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import scoped_session
from model.base import Base
from sqlalchemy import MetaData, Table
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import Index
from sqlalchemy import text
class Database:
@ -12,6 +15,18 @@ class Database:
self.session = scoped_session(sessionmaker(bind=self.engine))
def init_app(self, app):
self.engine = create_engine((app.config["DATABASE_URI"]))
self.session = scoped_session(sessionmaker(bind=self.engine))
# flask request handlers.
@app.after_request
def after_request(response):
self.session.remove()
return response
def create_tables(self):
Base.metadata.create_all(self.engine)
@ -26,3 +41,32 @@ class Database:
if table in meta.tables:
table = Table(table, meta, autoload_with=self.engine)
table.drop(self.engine, checkfirst=True)
def copy_table_structure(self, table, copy_table):
metadata = MetaData()
existing_table = Table(table, metadata, autoload_with=self.engine)
new_table = Table(
copy_table, metadata,
*(c.copy() for c in existing_table.columns),
extend_existing=True
)
new_table.create(self.engine, checkfirst=True)
for index in existing_table.indexes:
new_index_name = f"{index.name}_{copy_table}"
new_index = Index(
new_index_name,
*[new_table.c[col.name] for col in index.columns],
unique=index.unique
)
new_index.create(self.engine)
def rename_table(self, old_name, new_name):
with self.engine.connect() as conn:
conn.execute(text(f"DROP TABLE IF EXISTS {new_name}"))
conn.execute(text(f"ALTER TABLE {old_name} RENAME TO {new_name}"))

189
manage.py
View File

@ -1,26 +1,92 @@
import argparse
import csv
import copy
from datetime import datetime
from database import Database
from sqlalchemy import distinct
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import registry
from data.data_generator import StockDataGenerator
from data.indicators.sma import calculate_sma
from database import Database
from model.sma import Sma
from model.ticker import Ticker
from model.tickerSymbol import TickerSymbol
database = Database()
class CommandHandler:
def __init__(self, database=None, stock_generator=None):
self.database = database
self.stock_generator = stock_generator
self.mapper_registry = registry()
def load_database(chunk_size=100000):
database.drop_tables()
database.create_tables()
def create_tables(self):
self.database.create_tables()
with open("./data/output/tickers.csv") as file:
def generate_raw_daily_tickers(self, output_file):
self.stock_generator.generate_raw_data(output_file)
def load_raw_daily_tickers(self, path, chunk_size):
self._drop_table("ticker_rebuild")
copy_model = self._copy_existing_table_structure("ticker", "ticker_rebuild")
self._load_raw_ticker_data(copy_model, path, chunk_size)
self._rename_table("ticker", "ticker_backup")
self._rename_table("ticker_rebuild", "ticker")
def create_resource_sma(self):
self._drop_table("sma_rebuild")
copy_model = self._copy_existing_table_structure( "sma", "sma_rebuild")
self._generate_sma(copy_model)
self._rename_table("sma", "sma_backup")
self._rename_table("sma_rebuild", "sma")
def create_resource_ticker_symbols(self):
self._drop_table("ticker_symbols_rebuild")
copy_model = self._copy_existing_table_structure("ticker_symbol", "ticker_symbols_rebuild")
self._generate_ticker_symbols(copy_model)
self._rename_table("ticker_symbol", "ticker_symbol_backup")
self._rename_table("ticker_symbols_rebuild", "ticker_symbol")
def _copy_existing_table_structure(self, table_name, table_copy_name):
self.database.copy_table_structure(table_name, table_copy_name)
metadata = MetaData()
metadata.reflect(bind=self.database.engine)
copy_table = Table(table_copy_name, metadata, autoload_with=database.engine)
TempBase = declarative_base(metadata=metadata)
# Create a temporary table on the fly to load new data into.
class TempModel(TempBase):
__table__ = copy_table
return TempModel
def _rename_table(self, current_name, new_name):
self.database.rename_table(current_name, new_name)
def _drop_table(self, tablename):
self.database.drop_table(tablename)
def _load_raw_ticker_data(self, model, path, chunk_size):
with open(path, "r") as ticker_csv:
try:
rows = []
reader = csv.DictReader(file)
reader = csv.DictReader(ticker_csv)
total = 0
for row in reader:
ticker = Ticker(
ticker = model(
date=datetime.strptime(row.get("date"), "%Y-%m-%d").date(),
symbol=row.get("symbol"),
open=row.get("open"),
@ -38,29 +104,24 @@ def load_database(chunk_size=100000):
print(f"{total} rows inserted")
rows = []
# Insert any remaining rows.
# Remaining rows.
if rows:
database.session.bulk_save_objects(rows)
database.session.commit()
total += len(rows)
print(f"{total} inserted")
except Exception as e:
import pprint
pprint.pprint(e)
print(e)
def create_sma():
from sqlalchemy import distinct
database.drop_table("sma")
database.create_tables()
def _generate_sma(self, model):
sma_periods = [5, 10, 20, 50, 100, 200]
symbols = database.session.query(distinct(Ticker.symbol)).all()
symbols = self.database.session.query(distinct(Ticker.symbol)).all()
symbols = [symbol[0] for symbol in symbols]
count = 0
for symbol in symbols:
try:
daily_ticker_data = database.session.query(Ticker).filter_by(symbol=symbol[0]).all()
daily_ticker_data = self.database.session.query(Ticker).filter_by(symbol=symbol).all()
daily_ticker_data = [ticker.to_dict() for ticker in daily_ticker_data]
smas = {}
@ -70,24 +131,86 @@ def create_sma():
smas[period] = sma_data
for i in range(len(daily_ticker_data)):
row = Sma(
date=daily_ticker_data[i]['date'],
symbol=symbol[0],
sma_5=smas[5][i]['SMA'],
sma_10=smas[10][i]['SMA'],
sma_20=smas[20][i]['SMA'],
sma_50=smas[50][i]['SMA'],
sma_100=smas[100][i]['SMA'],
sma_200=smas[200][i]['SMA']
row = model(
date=datetime.strptime(daily_ticker_data[i]["date"], "%Y-%m-%d"),
symbol=symbol,
close=daily_ticker_data[i]["close"],
sma_5=smas[5][i]["SMA"],
sma_10=smas[10][i]["SMA"],
sma_20=smas[20][i]["SMA"],
sma_50=smas[50][i]["SMA"],
sma_100=smas[100][i]["SMA"],
sma_200=smas[200][i]["SMA"]
)
database.session.add(row)
database.session.commit()
self.database.session.add(row)
self.database.session.commit()
count += 1
print(f"finished {symbol[0]}: {count} of {len(symbols)}: time: {datetime.now()}")
print(f"finished {symbol}: {count} of {len(symbols)}: time: {datetime.now()}")
except Exception as e:
print(e)
database.session.rollback()
self.database.session.rollback()
def _generate_ticker_symbols(self, model):
symbols = self.database.session.query(distinct(Ticker.symbol)).all()
symbols = [symbol[0] for symbol in symbols]
for symbol in symbols:
ticker = model(symbol=symbol)
self.database.session.add(ticker)
self.database.session.commit()
def create_parser():
parser = argparse.ArgumentParser(description="Manager Script For Setup")
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Subparser for database command.
parser_database = subparsers.add_parser("database", help="Database management commands")
database_subparsers = parser_database.add_subparsers(dest="action", help="Database actions")
# Create action.
database_subparsers.add_parser("create", help="Create the database structure")
# Load action with chunk_size option.
parser_load = database_subparsers.add_parser("load", help="Load Raw Ticker Data into the database")
parser_load.add_argument("file", type=str, default="./data/output/tickers.csv", help="Location of the file")
parser_load.add_argument("--chunk_size", type=int, default=100000, help="Chunk size for loading data")
# Subparser for resource command.
parser_resource = subparsers.add_parser("resource", help="Resource management commands")
resource_subparsers = parser_resource.add_subparsers(dest="action", help="Resource actions")
resource_subparsers.add_parser("create_sma", help="Create SMA data")
resource_subparsers.add_parser("create_ticker_symbols", help="Get unique tickers")
# Subparser for generate command.
parser_generate = subparsers.add_parser("generate", help="Generate the Raw Ticker CSV")
parser_generate.add_argument("--output", type=str, default="./data/output/tickers.csv", help="Optional output file path")
return parser
if __name__ == "__main__":
create_sma()
parser = create_parser()
args = parser.parse_args()
database = Database(database_url="sqlite:///finance.db")
data_generator = StockDataGenerator()
handler = CommandHandler(
database=database,
stock_generator=data_generator
)
if args.command == "database":
if args.action == "create":
handler.create_tables()
elif args.action == "load":
handler.load_raw_daily_tickers(args.file, args.chunk_size)
elif args.command == "resource":
if args.action == "create_sma":
handler.create_resource_sma()
elif args.action == "create_ticker_symbols":
handler.create_resource_ticker_symbols()
elif args.command == "generate":
handler.generate_raw_daily_tickers(args.output)
else:
parser.print_help()

View File

@ -14,9 +14,24 @@ class Sma(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
date = Column(DateTime)
symbol = Column(String(15), nullable=False)
close = Column(Float)
sma_5 = Column(Float)
sma_10 = Column(Float)
sma_20 = Column(Float)
sma_50 = Column(Float)
sma_100 = Column(Float)
sma_200 = Column(Float)
def to_dict(self):
return {
"date": self.date.strftime("%Y-%m-%d"),
"symbol": self.symbol,
"close": self.close,
"sma_5": self.sma_5,
"sma_10": self.sma_10,
"sma_20": self.sma_20,
"sma_50": self.sma_50,
"sma_100": self.sma_100,
"sma_200": self.sma_200,
}

View File

@ -23,7 +23,7 @@ class Ticker(Base):
def to_dict(self):
return {
"date": self.date,
"date": self.date.strftime("%Y-%m-%d"),
"symbol": self.symbol,
"open": self.open,
"high": self.high,

21
model/tickerSymbol.py Normal file
View File

@ -0,0 +1,21 @@
from datetime import datetime
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Float
from sqlalchemy import Integer
from sqlalchemy import String
from model.base import Base
class TickerSymbol(Base):
__tablename__ = 'ticker_symbol'
id = Column(Integer, primary_key=True, autoincrement=True)
symbol = Column(String(15), nullable=False, unique=True)
def to_dict(self):
return {
"symbol": self.symbol,
}

51
rest_api.py Normal file
View File

@ -0,0 +1,51 @@
from flask import Flask
from database import Database
def create_app(config_object):
app = Flask(__name__)
app.config.from_object(config_object)
app.json.sort_keys = False
app.url_map.strict_slashes = False
database = Database()
database.init_app(app)
configure_blueprints(app, database)
return app
def configure_blueprints(app, database):
from api.views.historical import HistoricalView
from api.views.sma import SmaView
from api.views.ticker import TickerView
from api.service.historical import HistoricalService
from api.service.sma import SmaService
from api.service.ticker import TickerService
historical_service = HistoricalService(database=database)
historical_view = HistoricalView.as_view("historical", service=historical_service)
historical_url = "/api/v1/historical/<string:ticker>"
sma_service = SmaService(database=database)
sma_view = SmaView.as_view("sma", service=sma_service)
sma_url = "/api/v1/indicators/sma/<string:ticker>"
tickers_service = TickerService(database=database)
tickers_view = TickerView.as_view("tickers", service=tickers_service)
tickers_url = "/api/v1/tickers"
# Register routes with HTTP methods.
app.add_url_rule(historical_url, view_func=historical_view, methods=["GET"])
app.add_url_rule(sma_url, view_func=sma_view, methods=["GET"])
app.add_url_rule(tickers_url, view_func=tickers_view, methods=["GET"])
if __name__ == "__main__":
from api.config import ApiConfig
app = create_app(ApiConfig)
app.run(host="localhost", port=3000, debug=True)