class Error(Exception):
    """Base class for the API's domain errors.

    Subclasses in this module override ``message`` with the text that the
    views send back to the client.

    Args:
        value: Optional detail describing the failure. It is stored on
            ``self.value`` unless a ``value`` attribute already exists
            (for example as a class attribute on a subclass).
    """

    # Default client-facing text so ``error.message`` is always readable,
    # even on an instance of the base class itself.
    message = ""

    def __init__(self, value=""):
        if not hasattr(self, "value"):
            self.value = value
        # Forward the value to Exception so ``args``, ``repr()`` and pickling
        # carry it instead of seeing an argument-less exception.
        super().__init__(self.value)

    def __str__(self):
        """Return the ``repr`` of the stored value."""
        return repr(self.value)
class SmaService(BaseService):
    """Serve simple-moving-average (SMA) rows for a single ticker symbol.

    All queries go through ``self.database.session`` and read the ``Sma``
    model.
    """

    # ``_query_sma_data`` always selects symbol, date and close first; the
    # requested SMA columns are appended after these three.
    _FIXED_COLUMN_COUNT = 3

    def get_sma(self, ticker, params):
        """Return close prices and the requested SMA series for ``ticker``.

        Args:
            ticker: Symbol matched against ``Sma.symbol``.
            params: Mapping of request parameters. ``period`` is a
                comma-separated list of SMA periods; ``from_date`` and
                ``to_date`` are ``YYYY-MM-DD`` strings defaulting to
                ``2000-01-01`` and ``2001-01-01``.

        Returns:
            ``{"ticker": ticker, "data": [...]}`` where each list item holds
            ``date``, ``close`` and one entry per requested SMA column.

        Raises:
            InvalidSma: ``period`` is missing or names an unsupported period.
            DateFormatError: A date parameter is not ``YYYY-MM-DD``.
            ResourceNotFound: No rows match the symbol and date range.
        """
        # Function-scope imports keep this block self-contained.
        from datetime import datetime

        from api.errors import DateFormatError

        sma_columns = self._validate_and_get_columns(params.get("period"))

        # ``Sma.date`` is a DateTime column, so compare against datetime
        # objects rather than raw strings (HistoricalService parses its date
        # parameters the same way before filtering).
        defaults = (("from_date", "2000-01-01"), ("to_date", "2001-01-01"))
        dates = {}
        for key, default in defaults:
            try:
                dates[key] = datetime.strptime(
                    params.get(key, default), "%Y-%m-%d"
                )
            except (TypeError, ValueError):
                raise DateFormatError from None

        data = self._query_sma_data(ticker, sma_columns, dates)
        if not data:
            raise ResourceNotFound

        return {
            "ticker": ticker,
            "data": self._format_results(data, sma_columns),
        }

    def _validate_and_get_columns(self, sma_periods):
        """Map a comma-separated period string to ``Sma`` column attributes.

        Raises:
            InvalidSma: ``sma_periods`` is empty/``None`` or contains a
                period other than 5, 10, 20, 50, 100 or 200.
        """
        valid_sma_columns = {
            "5": Sma.sma_5,
            "10": Sma.sma_10,
            "20": Sma.sma_20,
            "50": Sma.sma_50,
            "100": Sma.sma_100,
            "200": Sma.sma_200,
        }

        # A missing ``period`` parameter arrives here as None; report it as a
        # client error instead of failing on ``None.split``.
        if not sma_periods:
            raise InvalidSma

        sma_columns = []
        for period in sma_periods.split(","):
            try:
                sma_columns.append(valid_sma_columns[period])
            except KeyError:
                raise InvalidSma from None
        return sma_columns

    def _query_sma_data(self, ticker, sma_columns, dates):
        """Select symbol, date, close and ``sma_columns`` for ``ticker``.

        ``dates`` supplies the inclusive ``from_date`` / ``to_date`` bounds.
        """
        query = self.database.session.query(Sma.symbol, Sma.date, Sma.close)
        query = query.add_columns(*sma_columns)
        query = query.filter(Sma.symbol == ticker)
        query = query.filter(Sma.date >= dates.get("from_date"))
        query = query.filter(Sma.date <= dates.get("to_date"))
        return query.all()

    def _format_results(self, data, sma_columns):
        """Turn result rows into dictionaries keyed by date, close and SMA name."""
        result = []
        for row in data:
            row_dict = {
                "date": row.date.strftime("%Y-%m-%d"),
                "close": row.close,
            }
            for i, column in enumerate(sma_columns):
                # SMA values start after the fixed symbol/date/close columns.
                row_dict[column.key] = row[self._FIXED_COLUMN_COUNT + i]
            result.append(row_dict)
        return result
class SmaView(BaseView):
    """GET endpoint that returns SMA data from the injected service."""

    def get(self, ticker):
        """Return SMA data for ``ticker`` as a JSON response.

        Query parameters are passed to the service unchanged via
        ``request.args``.

        Returns:
            A ``(response, status)`` tuple: 200 with the data, 400 for an
            invalid SMA period or date parameter, 404 when nothing matches,
            500 for any other failure.
        """
        # Function-scope import keeps this block self-contained.
        from api.errors import DateFormatError

        try:
            data = self.service.get_sma(ticker, request.args)
            # NOTE(review): the service result already carries its own
            # "data" key, so the payload is nested as data.data; the envelope
            # is kept unchanged here.
            return jsonify({"data": data}), 200
        except (InvalidSma, DateFormatError) as e:
            # Bad client input; HistoricalView maps DateFormatError to 400 too.
            return jsonify({"error": e.message}), 400
        except ResourceNotFound as e:
            # A missing resource is 404, matching HistoricalView.
            return jsonify({"error": e.message}), 404
        except Exception:
            return jsonify({"error": Messages.INTERNAL_SERVER_ERROR}), 500
def init_app(self, app):
    """Bind this database to a Flask application.

    Builds the engine from ``app.config["DATABASE_URI"]``, creates a scoped
    session on it, and registers a handler that removes the scoped session
    when the application context is torn down.

    Args:
        app: Flask application whose config contains ``DATABASE_URI``.
    """
    self.engine = create_engine(app.config["DATABASE_URI"])
    self.session = scoped_session(sessionmaker(bind=self.engine))

    # teardown_appcontext runs even when the request ended with an unhandled
    # exception, which after_request does not guarantee, so the session is
    # always released.
    @app.teardown_appcontext
    def shutdown_session(exception=None):
        self.session.remove()
def load_raw_daily_tickers(self, path, chunk_size):
    """Rebuild the ``ticker`` table from a raw CSV file via a staging table.

    Args:
        path: Location of the CSV file, passed to ``_load_raw_ticker_data``.
        chunk_size: Row batch size, passed to ``_load_raw_ticker_data``.
    """
    # Start from a clean staging table with the same structure as ``ticker``.
    self._drop_table("ticker_rebuild")
    copy_model = self._copy_existing_table_structure("ticker", "ticker_rebuild")
    self._load_raw_ticker_data(copy_model, path, chunk_size)
    # Swap: keep the current table as ``ticker_backup``, then promote the
    # staging table to ``ticker``.
    self._rename_table("ticker", "ticker_backup")
    self._rename_table("ticker_rebuild", "ticker")
def _load_raw_ticker_data(self, model, path, chunk_size):
    """Bulk-insert the rows of a ticker CSV file using ``model``.

    Args:
        model: Class instantiated once per CSV row with the keyword
            arguments ``date``, ``symbol``, ``open``, ``high``, ``low``,
            ``close`` and ``volume``.
        path: CSV file whose header names those columns; ``date`` must be
            formatted ``YYYY-MM-DD``.
        chunk_size: Number of rows buffered before each bulk save + commit.

    Raises:
        Exception: Any failure is rolled back, printed and re-raised so the
            caller does not go on to swap a partially loaded table into
            place.
    """
    # Use the injected database, not a module-level global that only exists
    # when this file runs as a script.
    session = self.database.session
    rows = []
    total = 0

    # newline="" is the mode the csv module documents for file objects.
    with open(path, "r", newline="") as ticker_csv:
        try:
            for row in csv.DictReader(ticker_csv):
                rows.append(
                    model(
                        date=datetime.strptime(row.get("date"), "%Y-%m-%d").date(),
                        symbol=row.get("symbol"),
                        open=row.get("open"),
                        high=row.get("high"),
                        low=row.get("low"),
                        close=row.get("close"),
                        volume=row.get("volume"),
                    )
                )

                if len(rows) >= chunk_size:
                    session.bulk_save_objects(rows)
                    session.commit()
                    total += len(rows)
                    print(f"{total} rows inserted")
                    rows = []

            # Remaining rows.
            if rows:
                session.bulk_save_objects(rows)
                session.commit()
                total += len(rows)
                print(f"{total} inserted")
        except Exception as e:
            # Discard the uncommitted batch before reporting the failure.
            session.rollback()
            print(e)
            raise
def create_parser():
    """Build the command-line parser for the management script.

    Commands:
        database create              Create the database structure.
        database load [file]         Load raw ticker data; ``file`` defaults
                                     to ``./data/output/tickers.csv`` and
                                     ``--chunk_size`` to 100000.
        resource create_sma          Create SMA data.
        resource create_ticker_symbols
                                     Store the unique ticker symbols.
        generate [--output PATH]     Generate the raw ticker CSV.

    Returns:
        The configured ``argparse.ArgumentParser``. The chosen command is
        stored in ``command`` and, for ``database`` / ``resource``, the
        sub-action in ``action``.
    """
    default_csv = "./data/output/tickers.csv"

    parser = argparse.ArgumentParser(description="Manager Script For Setup")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Subparser for database command.
    parser_database = subparsers.add_parser(
        "database", help="Database management commands"
    )
    database_subparsers = parser_database.add_subparsers(
        dest="action", help="Database actions"
    )

    # Create action.
    database_subparsers.add_parser("create", help="Create the database structure")

    # Load action with chunk_size option.
    parser_load = database_subparsers.add_parser(
        "load", help="Load Raw Ticker Data into the database"
    )
    # nargs="?" makes the positional optional; without it argparse always
    # requires the argument and the default can never be used.
    parser_load.add_argument(
        "file",
        type=str,
        nargs="?",
        default=default_csv,
        help="Location of the file",
    )
    parser_load.add_argument(
        "--chunk_size",
        type=int,
        default=100000,
        help="Chunk size for loading data",
    )

    # Subparser for resource command.
    parser_resource = subparsers.add_parser(
        "resource", help="Resource management commands"
    )
    resource_subparsers = parser_resource.add_subparsers(
        dest="action", help="Resource actions"
    )
    resource_subparsers.add_parser("create_sma", help="Create SMA data")
    resource_subparsers.add_parser(
        "create_ticker_symbols", help="Get unique tickers"
    )

    # Subparser for generate command.
    parser_generate = subparsers.add_parser(
        "generate", help="Generate the Raw Ticker CSV"
    )
    parser_generate.add_argument(
        "--output",
        type=str,
        default=default_csv,
        help="Optional output file path",
    )

    return parser
def configure_blueprints(app, database):
    """Create the services and views and register their GET routes on ``app``.

    Args:
        app: Flask application that receives the URL rules.
        database: Database object handed to every service.
    """
    from api.views.historical import HistoricalView
    from api.views.sma import SmaView
    from api.views.ticker import TickerView
    from api.service.historical import HistoricalService
    from api.service.sma import SmaService
    from api.service.ticker import TickerService

    # HistoricalView.get and SmaView.get both take a ``ticker`` argument, so
    # their rules must capture it from the URL; without the variable Flask
    # calls ``get()`` with no arguments and the request fails.
    historical_service = HistoricalService(database=database)
    historical_view = HistoricalView.as_view("historical", service=historical_service)
    historical_url = "/api/v1/historical/<ticker>"

    sma_service = SmaService(database=database)
    sma_view = SmaView.as_view("sma", service=sma_service)
    sma_url = "/api/v1/indicators/sma/<ticker>"

    tickers_service = TickerService(database=database)
    tickers_view = TickerView.as_view("tickers", service=tickers_service)
    tickers_url = "/api/v1/tickers"

    # Register routes with HTTP methods.
    app.add_url_rule(historical_url, view_func=historical_view, methods=["GET"])
    app.add_url_rule(sma_url, view_func=sma_view, methods=["GET"])
    app.add_url_rule(tickers_url, view_func=tickers_view, methods=["GET"])