Update finance api project to include rest api (#2)

Reviewed-on: #2
This commit is contained in:
thecodebranch 2024-07-04 15:51:34 +00:00
parent 8a4f34855f
commit dce8860e2e
18 changed files with 559 additions and 82 deletions

2
api/config.py Normal file
View File

@ -0,0 +1,2 @@
class ApiConfig:
DATABASE_URI = "sqlite:///finance.db"

23
api/errors.py Normal file
View File

@ -0,0 +1,23 @@
class Error(Exception):
def __init__(self, value=""):
if not hasattr(self, "value"):
self.value = value
def __str__(self):
return repr(self.value)
class InvalidSma(Error):
message = "Invalid SMA period."
class ResourceNotFound(Error):
message = "Resource not found."
class DateFormatError(Error):
message = "Error with date parameters."
class Messages:
INTERNAL_SERVER_ERROR = "Server error. please try again later."

3
api/service/base.py Normal file
View File

@ -0,0 +1,3 @@
class BaseService:
def __init__(self, database=None):
self.database = database

45
api/service/historical.py Normal file
View File

@ -0,0 +1,45 @@
from datetime import datetime
from .base import BaseService
from api.errors import DateFormatError
from api.errors import ResourceNotFound
from model.ticker import Ticker
def validate_date(date_str):
try:
date_obj = datetime.strptime(date_str, "%Y-%m-%d")
return date_obj
except ValueError:
raise DateFormatError
class HistoricalService(BaseService):
def get_historical_ticker(self, symbol, params):
from_date = params.get("from_date", "2000-01-01")
to_date = params.get("to_date", "2001-01-01")
from_date = validate_date(from_date)
to_date = validate_date(to_date)
historical = (
self.database.session.query(Ticker)
.filter(Ticker.symbol == symbol)
.filter(Ticker.date >= from_date)
.filter(Ticker.date <= to_date)
.all()
)
if not historical:
raise ResourceNotFound
output = {}
data = []
for ticker in historical:
ticker = ticker.to_dict()
ticker.pop("symbol")
data.append(ticker)
output["ticker"] = symbol
output["data"] = data
return output

79
api/service/sma.py Normal file
View File

@ -0,0 +1,79 @@
from .base import BaseService
from model.sma import Sma
from api.errors import ResourceNotFound
from api.errors import InvalidSma
class SmaService(BaseService):
def get_sma(self, ticker, params):
# Validate SMA periods and get corresponding columns
sma_periods = params.get("period")
sma_columns = self._validate_and_get_columns(sma_periods)
from_date = params.get("from_date", "2000-01-01")
to_date = params.get("to_date", "2001-01-01")
dates = {"from_date": from_date, "to_date": to_date}
# Query the database for the necessary columns
data = self._query_sma_data(ticker, sma_columns, dates)
if not data:
raise ResourceNotFound
data = self._format_results(data, sma_columns)
result = {
"ticker": ticker,
"data": data
}
return result
def _validate_and_get_columns(self, sma_periods):
valid_sma_columns = {
"5": Sma.sma_5,
"10": Sma.sma_10,
"20": Sma.sma_20,
"50": Sma.sma_50,
"100": Sma.sma_100,
"200": Sma.sma_200,
}
sma_columns = []
for period in sma_periods.split(","):
if period not in valid_sma_columns.keys():
raise InvalidSma
sma_columns.append(valid_sma_columns[period])
return sma_columns
def _query_sma_data(self, ticker, sma_columns, dates):
# Start with the base query selecting the date and symbol
query = self.database.session.query(Sma.symbol, Sma.date, Sma.close)
# Add the SMA columns dynamically
query = query.add_columns(*sma_columns)
# Filter by the ticker symbol
query = query.filter(Sma.symbol == ticker)
query = query.filter(Sma.date >= dates.get("from_date"))
query = query.filter(Sma.date <= dates.get("to_date"))
# Execute the query and return the results
return query.all()
def _format_results(self, data, sma_columns):
# Format the results into a list of dictionaries
result = []
for row in data:
row_dict = {
"date": row.date.strftime("%Y-%m-%d"),
"close": row.close,
}
for i, column in enumerate(sma_columns):
# +2 because the first two columns are date and symbol
sma = row[i + 3]
row_dict[column.key] = sma
result.append(row_dict)
return result

11
api/service/ticker.py Normal file
View File

@ -0,0 +1,11 @@
from model.tickerSymbol import TickerSymbol
from .base import BaseService
class TickerService(BaseService):
def get_tickers(self):
try:
tickers = self.database.session.query(TickerSymbol).all()
return [ticker.symbol for ticker in tickers]
except Exception as e:
pass

6
api/views/base.py Normal file
View File

@ -0,0 +1,6 @@
from flask.views import MethodView
class BaseView(MethodView):
def __init__(self, service=None):
self.service=service

19
api/views/historical.py Normal file
View File

@ -0,0 +1,19 @@
from flask import jsonify
from flask import request
from .base import BaseView
from api.errors import DateFormatError
from api.errors import Messages
from api.errors import ResourceNotFound
class HistoricalView(BaseView):
def get(self, ticker):
try:
data = self.service.get_historical_ticker(ticker, params=request.args)
return jsonify(data), 200
except ResourceNotFound as e:
return jsonify({"error": e.message}), 404
except DateFormatError as e:
return jsonify({"error": e.message}), 400
except Exception as e:
return jsonify({"error": Messages.INTERNAL_SERVER_ERROR}), 500

19
api/views/sma.py Normal file
View File

@ -0,0 +1,19 @@
from flask import jsonify
from flask import request
from .base import BaseView
from api.errors import ResourceNotFound
from api.errors import InvalidSma
from api.errors import Messages
class SmaView(BaseView):
def get(self, ticker):
try:
data = self.service.get_sma(ticker, request.args)
return jsonify({"data": data}), 200
except InvalidSma as e:
return jsonify({"error": e.message}), 400
except ResourceNotFound as e:
return jsonify({"error": e.message}), 400
except Exception as e:
return jsonify({"error": Messages.INTERNAL_SERVER_ERROR}), 500

12
api/views/ticker.py Normal file
View File

@ -0,0 +1,12 @@
from flask import jsonify
from .base import BaseView
from api.errors import Messages
class TickerView(BaseView):
def get(self):
try:
data = self.service.get_tickers()
return jsonify({"data": data}), 200
except Exception as e:
return jsonify({"error": Messages.INTERNAL_SERVER_ERROR}), 500

View File

@ -29,14 +29,14 @@ class StockDataGenerator:
self.volume_adjustment_factor = volume_adjustment_factor self.volume_adjustment_factor = volume_adjustment_factor
def generate_raw_data(self, start_date="2000-01-01", end_date="2024-01-01"): def generate_raw_data(self, output_path=None, start_date="2000-01-01", end_date="2024-01-01"):
tickers = self._generate_fake_tickers() tickers = self._generate_fake_tickers()
dates = self._generate_dates() dates = self._generate_dates()
ticker_count = 0 ticker_count = 0
for ticker in tickers: for ticker in tickers:
daily_ticker_data = self._generate_stock_data(ticker, dates) daily_ticker_data = self._generate_stock_data(ticker, dates)
self._write_to_csv(daily_ticker_data) self._write_to_csv(daily_ticker_data, output_path)
ticker_count += 1 ticker_count += 1
print(f"Generated data for: {ticker} and is {ticker_count} of {len(tickers)}") print(f"Generated data for: {ticker} and is {ticker_count} of {len(tickers)}")
@ -109,8 +109,12 @@ class StockDataGenerator:
return stock_data return stock_data
def _write_to_csv(self, data): def _write_to_csv(self, data, output=None):
with open(self.csv_output_file, "a") as file: if output:
output_file = output
else:
output_file = self.csv_output_file
with open(output_file, "a") as file:
fieldnames = ["date", "symbol", "open", "high", "low", "close", "volume"] fieldnames = ["date", "symbol", "open", "high", "low", "close", "volume"]
writer = csv.DictWriter(file, fieldnames=fieldnames) writer = csv.DictWriter(file, fieldnames=fieldnames)
if file.tell() == 0: if file.tell() == 0:

View File

@ -13,7 +13,7 @@ def calculate_sma(data, period):
sma_values.append(None) sma_values.append(None)
else: else:
sma = sum(close_prices[i + 1 - period:i + 1]) / period sma = sum(close_prices[i + 1 - period:i + 1]) / period
sma_values.append(sma) sma_values.append(round(sma, 3))
for i in range(len(data)): for i in range(len(data)):
data[i]['SMA'] = sma_values[i] data[i]['SMA'] = sma_values[i]

View File

@ -1,9 +1,12 @@
from datetime import datetime
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session
from model.base import Base from model.base import Base
from sqlalchemy import MetaData, Table from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import Index
from sqlalchemy import text
class Database: class Database:
@ -11,6 +14,18 @@ class Database:
self.engine = create_engine(database_url) self.engine = create_engine(database_url)
self.session = scoped_session(sessionmaker(bind=self.engine)) self.session = scoped_session(sessionmaker(bind=self.engine))
def init_app(self, app):
self.engine = create_engine((app.config["DATABASE_URI"]))
self.session = scoped_session(sessionmaker(bind=self.engine))
# flask request handlers.
@app.after_request
def after_request(response):
self.session.remove()
return response
def create_tables(self): def create_tables(self):
Base.metadata.create_all(self.engine) Base.metadata.create_all(self.engine)
@ -25,4 +40,33 @@ class Database:
meta.reflect(bind=self.engine) meta.reflect(bind=self.engine)
if table in meta.tables: if table in meta.tables:
table = Table(table, meta, autoload_with=self.engine) table = Table(table, meta, autoload_with=self.engine)
table.drop(self.engine, checkfirst=True) table.drop(self.engine, checkfirst=True)
def copy_table_structure(self, table, copy_table):
metadata = MetaData()
existing_table = Table(table, metadata, autoload_with=self.engine)
new_table = Table(
copy_table, metadata,
*(c.copy() for c in existing_table.columns),
extend_existing=True
)
new_table.create(self.engine, checkfirst=True)
for index in existing_table.indexes:
new_index_name = f"{index.name}_{copy_table}"
new_index = Index(
new_index_name,
*[new_table.c[col.name] for col in index.columns],
unique=index.unique
)
new_index.create(self.engine)
def rename_table(self, old_name, new_name):
with self.engine.connect() as conn:
conn.execute(text(f"DROP TABLE IF EXISTS {new_name}"))
conn.execute(text(f"ALTER TABLE {old_name} RENAME TO {new_name}"))

267
manage.py
View File

@ -1,93 +1,216 @@
import argparse
import csv import csv
import copy import copy
from datetime import datetime from datetime import datetime
from database import Database from sqlalchemy import distinct
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import registry
from data.data_generator import StockDataGenerator
from data.indicators.sma import calculate_sma from data.indicators.sma import calculate_sma
from database import Database
from model.sma import Sma from model.sma import Sma
from model.ticker import Ticker from model.ticker import Ticker
from model.tickerSymbol import TickerSymbol
database = Database() class CommandHandler:
def __init__(self, database=None, stock_generator=None):
self.database = database
self.stock_generator = stock_generator
self.mapper_registry = registry()
def load_database(chunk_size=100000): def create_tables(self):
database.drop_tables() self.database.create_tables()
database.create_tables()
with open("./data/output/tickers.csv") as file:
try:
rows = []
reader = csv.DictReader(file)
total = 0
for row in reader:
ticker = Ticker(
date=datetime.strptime(row.get("date"), "%Y-%m-%d").date(),
symbol=row.get("symbol"),
open=row.get("open"),
high=row.get("high"),
low=row.get("low"),
close=row.get("close"),
volume=row.get("volume")
)
rows.append(ticker)
if len(rows) >= chunk_size: def generate_raw_daily_tickers(self, output_file):
self.stock_generator.generate_raw_data(output_file)
def load_raw_daily_tickers(self, path, chunk_size):
self._drop_table("ticker_rebuild")
copy_model = self._copy_existing_table_structure("ticker", "ticker_rebuild")
self._load_raw_ticker_data(copy_model, path, chunk_size)
self._rename_table("ticker", "ticker_backup")
self._rename_table("ticker_rebuild", "ticker")
def create_resource_sma(self):
self._drop_table("sma_rebuild")
copy_model = self._copy_existing_table_structure( "sma", "sma_rebuild")
self._generate_sma(copy_model)
self._rename_table("sma", "sma_backup")
self._rename_table("sma_rebuild", "sma")
def create_resource_ticker_symbols(self):
self._drop_table("ticker_symbols_rebuild")
copy_model = self._copy_existing_table_structure("ticker_symbol", "ticker_symbols_rebuild")
self._generate_ticker_symbols(copy_model)
self._rename_table("ticker_symbol", "ticker_symbol_backup")
self._rename_table("ticker_symbols_rebuild", "ticker_symbol")
def _copy_existing_table_structure(self, table_name, table_copy_name):
self.database.copy_table_structure(table_name, table_copy_name)
metadata = MetaData()
metadata.reflect(bind=self.database.engine)
copy_table = Table(table_copy_name, metadata, autoload_with=database.engine)
TempBase = declarative_base(metadata=metadata)
# Create a temporary table on the fly to load new data into.
class TempModel(TempBase):
__table__ = copy_table
return TempModel
def _rename_table(self, current_name, new_name):
self.database.rename_table(current_name, new_name)
def _drop_table(self, tablename):
self.database.drop_table(tablename)
def _load_raw_ticker_data(self, model, path, chunk_size):
with open(path, "r") as ticker_csv:
try:
rows = []
reader = csv.DictReader(ticker_csv)
total = 0
for row in reader:
ticker = model(
date=datetime.strptime(row.get("date"), "%Y-%m-%d").date(),
symbol=row.get("symbol"),
open=row.get("open"),
high=row.get("high"),
low=row.get("low"),
close=row.get("close"),
volume=row.get("volume")
)
rows.append(ticker)
if len(rows) >= chunk_size:
database.session.bulk_save_objects(rows)
database.session.commit()
total += chunk_size
print(f"{total} rows inserted")
rows = []
# Remaining rows.
if rows:
database.session.bulk_save_objects(rows) database.session.bulk_save_objects(rows)
database.session.commit() database.session.commit()
total += chunk_size total += len(rows)
print(f"{total} rows inserted") print(f"{total} inserted")
rows = [] except Exception as e:
print(e)
# Insert any remaining rows.
if rows:
database.session.bulk_save_objects(rows)
database.session.commit()
total += len(rows)
print(f"{total} inserted")
except Exception as e:
import pprint
pprint.pprint(e)
def create_sma(): def _generate_sma(self, model):
from sqlalchemy import distinct sma_periods = [5, 10, 20, 50, 100, 200]
symbols = self.database.session.query(distinct(Ticker.symbol)).all()
symbols = [symbol[0] for symbol in symbols]
count = 0
for symbol in symbols:
try:
daily_ticker_data = self.database.session.query(Ticker).filter_by(symbol=symbol).all()
daily_ticker_data = [ticker.to_dict() for ticker in daily_ticker_data]
smas = {}
for period in sma_periods:
data = copy.deepcopy(daily_ticker_data)
sma_data = calculate_sma(data, period)
smas[period] = sma_data
for i in range(len(daily_ticker_data)):
row = model(
date=datetime.strptime(daily_ticker_data[i]["date"], "%Y-%m-%d"),
symbol=symbol,
close=daily_ticker_data[i]["close"],
sma_5=smas[5][i]["SMA"],
sma_10=smas[10][i]["SMA"],
sma_20=smas[20][i]["SMA"],
sma_50=smas[50][i]["SMA"],
sma_100=smas[100][i]["SMA"],
sma_200=smas[200][i]["SMA"]
)
self.database.session.add(row)
self.database.session.commit()
count += 1
print(f"finished {symbol}: {count} of {len(symbols)}: time: {datetime.now()}")
except Exception as e:
print(e)
self.database.session.rollback()
database.drop_table("sma")
database.create_tables()
sma_periods = [5, 10, 20, 50, 100, 200] def _generate_ticker_symbols(self, model):
symbols = database.session.query(distinct(Ticker.symbol)).all() symbols = self.database.session.query(distinct(Ticker.symbol)).all()
count = 0 symbols = [symbol[0] for symbol in symbols]
for symbol in symbols: for symbol in symbols:
try: ticker = model(symbol=symbol)
daily_ticker_data = database.session.query(Ticker).filter_by(symbol=symbol[0]).all() self.database.session.add(ticker)
daily_ticker_data = [ticker.to_dict() for ticker in daily_ticker_data] self.database.session.commit()
smas = {}
for period in sma_periods: def create_parser():
data = copy.deepcopy(daily_ticker_data) parser = argparse.ArgumentParser(description="Manager Script For Setup")
sma_data = calculate_sma(data, period) subparsers = parser.add_subparsers(dest="command", help="Available commands")
smas[period] = sma_data
# Subparser for database command.
for i in range(len(daily_ticker_data)): parser_database = subparsers.add_parser("database", help="Database management commands")
row = Sma( database_subparsers = parser_database.add_subparsers(dest="action", help="Database actions")
date=daily_ticker_data[i]['date'],
symbol=symbol[0], # Create action.
sma_5=smas[5][i]['SMA'], database_subparsers.add_parser("create", help="Create the database structure")
sma_10=smas[10][i]['SMA'],
sma_20=smas[20][i]['SMA'], # Load action with chunk_size option.
sma_50=smas[50][i]['SMA'], parser_load = database_subparsers.add_parser("load", help="Load Raw Ticker Data into the database")
sma_100=smas[100][i]['SMA'], parser_load.add_argument("file", type=str, default="./data/output/tickers.csv", help="Location of the file")
sma_200=smas[200][i]['SMA'] parser_load.add_argument("--chunk_size", type=int, default=100000, help="Chunk size for loading data")
)
database.session.add(row) # Subparser for resource command.
database.session.commit() parser_resource = subparsers.add_parser("resource", help="Resource management commands")
count += 1 resource_subparsers = parser_resource.add_subparsers(dest="action", help="Resource actions")
print(f"finished {symbol[0]}: {count} of {len(symbols)}: time: {datetime.now()}") resource_subparsers.add_parser("create_sma", help="Create SMA data")
except Exception as e: resource_subparsers.add_parser("create_ticker_symbols", help="Get unique tickers")
print(e)
database.session.rollback() # Subparser for generate command.
parser_generate = subparsers.add_parser("generate", help="Generate the Raw Ticker CSV")
parser_generate.add_argument("--output", type=str, default="./data/output/tickers.csv", help="Optional output file path")
return parser
if __name__ == "__main__": if __name__ == "__main__":
create_sma() parser = create_parser()
args = parser.parse_args()
database = Database(database_url="sqlite:///finance.db")
data_generator = StockDataGenerator()
handler = CommandHandler(
database=database,
stock_generator=data_generator
)
if args.command == "database":
if args.action == "create":
handler.create_tables()
elif args.action == "load":
handler.load_raw_daily_tickers(args.file, args.chunk_size)
elif args.command == "resource":
if args.action == "create_sma":
handler.create_resource_sma()
elif args.action == "create_ticker_symbols":
handler.create_resource_ticker_symbols()
elif args.command == "generate":
handler.generate_raw_daily_tickers(args.output)
else:
parser.print_help()

View File

@ -14,9 +14,24 @@ class Sma(Base):
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
date = Column(DateTime) date = Column(DateTime)
symbol = Column(String(15), nullable=False) symbol = Column(String(15), nullable=False)
close = Column(Float)
sma_5 = Column(Float) sma_5 = Column(Float)
sma_10 = Column(Float) sma_10 = Column(Float)
sma_20 = Column(Float) sma_20 = Column(Float)
sma_50 = Column(Float) sma_50 = Column(Float)
sma_100 = Column(Float) sma_100 = Column(Float)
sma_200 = Column(Float) sma_200 = Column(Float)
def to_dict(self):
return {
"date": self.date.strftime("%Y-%m-%d"),
"symbol": self.symbol,
"close": self.close,
"sma_5": self.sma_5,
"sma_10": self.sma_10,
"sma_20": self.sma_20,
"sma_50": self.sma_50,
"sma_100": self.sma_100,
"sma_200": self.sma_200,
}

View File

@ -23,7 +23,7 @@ class Ticker(Base):
def to_dict(self): def to_dict(self):
return { return {
"date": self.date, "date": self.date.strftime("%Y-%m-%d"),
"symbol": self.symbol, "symbol": self.symbol,
"open": self.open, "open": self.open,
"high": self.high, "high": self.high,

21
model/tickerSymbol.py Normal file
View File

@ -0,0 +1,21 @@
from datetime import datetime
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Float
from sqlalchemy import Integer
from sqlalchemy import String
from model.base import Base
class TickerSymbol(Base):
__tablename__ = 'ticker_symbol'
id = Column(Integer, primary_key=True, autoincrement=True)
symbol = Column(String(15), nullable=False, unique=True)
def to_dict(self):
return {
"symbol": self.symbol,
}

51
rest_api.py Normal file
View File

@ -0,0 +1,51 @@
from flask import Flask
from database import Database
def create_app(config_object):
app = Flask(__name__)
app.config.from_object(config_object)
app.json.sort_keys = False
app.url_map.strict_slashes = False
database = Database()
database.init_app(app)
configure_blueprints(app, database)
return app
def configure_blueprints(app, database):
from api.views.historical import HistoricalView
from api.views.sma import SmaView
from api.views.ticker import TickerView
from api.service.historical import HistoricalService
from api.service.sma import SmaService
from api.service.ticker import TickerService
historical_service = HistoricalService(database=database)
historical_view = HistoricalView.as_view("historical", service=historical_service)
historical_url = "/api/v1/historical/<string:ticker>"
sma_service = SmaService(database=database)
sma_view = SmaView.as_view("sma", service=sma_service)
sma_url = "/api/v1/indicators/sma/<string:ticker>"
tickers_service = TickerService(database=database)
tickers_view = TickerView.as_view("tickers", service=tickers_service)
tickers_url = "/api/v1/tickers"
# Register routes with HTTP methods.
app.add_url_rule(historical_url, view_func=historical_view, methods=["GET"])
app.add_url_rule(sma_url, view_func=sma_view, methods=["GET"])
app.add_url_rule(tickers_url, view_func=tickers_view, methods=["GET"])
if __name__ == "__main__":
from api.config import ApiConfig
app = create_app(ApiConfig)
app.run(host="localhost", port=3000, debug=True)