finance-api/manage.py

216 lines
8.2 KiB
Python

import argparse
import csv
import copy
from datetime import datetime
from sqlalchemy import distinct
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import registry
from data.data_generator import StockDataGenerator
from data.indicators.sma import calculate_sma
from database import Database
from model.sma import Sma
from model.ticker import Ticker
from model.tickerSymbol import TickerSymbol
class CommandHandler:
    """Implements the management commands exposed by this script.

    The rebuild commands share one pattern: drop any stale ``*_rebuild``
    table, copy the structure of the live table into a fresh rebuild table,
    fill it, then rename the live table to ``*_backup`` and the rebuild
    table to the live name.

    All database access goes through the ``database`` object passed to the
    constructor (its ``engine``, ``session`` and table helper methods).
    """

    def __init__(self, database=None, stock_generator=None):
        """Store the collaborators used by the commands.

        Args:
            database: Project database wrapper; the methods below use its
                ``engine``, ``session``, ``create_tables``,
                ``copy_table_structure``, ``rename_table`` and ``drop_table``.
            stock_generator: Object providing ``generate_raw_data(output_file)``.
        """
        self.database = database
        self.stock_generator = stock_generator
        self.mapper_registry = registry()

    def create_tables(self):
        """Create the database structure via the database wrapper."""
        self.database.create_tables()

    def generate_raw_daily_tickers(self, output_file):
        """Generate the raw ticker data into ``output_file``."""
        self.stock_generator.generate_raw_data(output_file)

    def load_raw_daily_tickers(self, path, chunk_size):
        """Rebuild the ``ticker`` table from the CSV file at ``path``.

        Rows are inserted ``chunk_size`` at a time. If loading fails the
        exception propagates and the live ``ticker`` table is not swapped.
        """
        self._drop_table("ticker_rebuild")
        copy_model = self._copy_existing_table_structure("ticker", "ticker_rebuild")
        self._load_raw_ticker_data(copy_model, path, chunk_size)
        self._rename_table("ticker", "ticker_backup")
        self._rename_table("ticker_rebuild", "ticker")

    def create_resource_sma(self):
        """Rebuild the ``sma`` table from the current ticker data."""
        self._drop_table("sma_rebuild")
        copy_model = self._copy_existing_table_structure("sma", "sma_rebuild")
        self._generate_sma(copy_model)
        self._rename_table("sma", "sma_backup")
        self._rename_table("sma_rebuild", "sma")

    def create_resource_ticker_symbols(self):
        """Rebuild the ``ticker_symbol`` table with the distinct ticker symbols."""
        self._drop_table("ticker_symbols_rebuild")
        copy_model = self._copy_existing_table_structure("ticker_symbol", "ticker_symbols_rebuild")
        self._generate_ticker_symbols(copy_model)
        self._rename_table("ticker_symbol", "ticker_symbol_backup")
        self._rename_table("ticker_symbols_rebuild", "ticker_symbol")

    def _copy_existing_table_structure(self, table_name, table_copy_name):
        """Copy a table's structure and return a model class mapped to the copy.

        Args:
            table_name: Name of the existing table to copy.
            table_copy_name: Name of the new table to create.

        Returns:
            A declarative model class whose ``__table__`` is the reflected
            copy table.
        """
        self.database.copy_table_structure(table_name, table_copy_name)
        metadata = MetaData()
        metadata.reflect(bind=self.database.engine)
        # Use this handler's own database rather than a module-level global,
        # so the class also works when imported from another module.
        copy_table = Table(table_copy_name, metadata, autoload_with=self.database.engine)
        TempBase = declarative_base(metadata=metadata)

        # Create a temporary model on the fly to load new data into.
        class TempModel(TempBase):
            __table__ = copy_table

        return TempModel

    def _rename_table(self, current_name, new_name):
        """Rename ``current_name`` to ``new_name`` via the database wrapper."""
        self.database.rename_table(current_name, new_name)

    def _drop_table(self, tablename):
        """Drop ``tablename`` via the database wrapper."""
        self.database.drop_table(tablename)

    def _load_raw_ticker_data(self, model, path, chunk_size):
        """Insert the rows of the ticker CSV at ``path`` as ``model`` objects.

        The CSV must have the columns ``date`` (``YYYY-MM-DD``), ``symbol``,
        ``open``, ``high``, ``low``, ``close`` and ``volume``.

        Args:
            model: Callable accepting those columns as keyword arguments.
            path: Location of the CSV file.
            chunk_size: Number of rows to buffer before each bulk insert.

        Raises:
            Exception: Any parsing or database error is printed, the session
                is rolled back, and the error is re-raised so the caller does
                not promote a partially loaded table.
        """
        session = self.database.session
        rows = []
        total = 0
        # newline="" is what the csv module expects for files it reads.
        with open(path, "r", newline="") as ticker_csv:
            try:
                for row in csv.DictReader(ticker_csv):
                    rows.append(
                        model(
                            date=datetime.strptime(row.get("date"), "%Y-%m-%d").date(),
                            symbol=row.get("symbol"),
                            open=row.get("open"),
                            high=row.get("high"),
                            low=row.get("low"),
                            close=row.get("close"),
                            volume=row.get("volume"),
                        )
                    )
                    if len(rows) >= chunk_size:
                        session.bulk_save_objects(rows)
                        session.commit()
                        # Count what was actually flushed, not the nominal size.
                        total += len(rows)
                        print(f"{total} rows inserted")
                        rows = []
                # Remaining rows.
                if rows:
                    session.bulk_save_objects(rows)
                    session.commit()
                    total += len(rows)
                    print(f"{total} inserted")
            except Exception as e:
                print(e)
                session.rollback()
                raise

    def _generate_sma(self, model):
        """Compute the simple moving averages per symbol and store them.

        For every distinct ticker symbol, the daily rows are read, the SMA is
        calculated for each period, and one ``model`` row per day is added and
        committed. A failure for one symbol is printed and rolled back, and
        processing continues with the next symbol.
        """
        sma_periods = [5, 10, 20, 50, 100, 200]
        session = self.database.session
        symbols = [found[0] for found in session.query(distinct(Ticker.symbol)).all()]
        symbol_count = len(symbols)
        count = 0
        for symbol in symbols:
            try:
                daily_ticker_data = [
                    ticker.to_dict()
                    for ticker in session.query(Ticker).filter_by(symbol=symbol).all()
                ]
                # Each calculation gets its own deep copy of the rows
                # (presumably because calculate_sma may modify its input).
                smas = {
                    period: calculate_sma(copy.deepcopy(daily_ticker_data), period)
                    for period in sma_periods
                }
                for i, day in enumerate(daily_ticker_data):
                    row = model(
                        date=datetime.strptime(day["date"], "%Y-%m-%d"),
                        symbol=symbol,
                        close=day["close"],
                        sma_5=smas[5][i]["SMA"],
                        sma_10=smas[10][i]["SMA"],
                        sma_20=smas[20][i]["SMA"],
                        sma_50=smas[50][i]["SMA"],
                        sma_100=smas[100][i]["SMA"],
                        sma_200=smas[200][i]["SMA"],
                    )
                    session.add(row)
                session.commit()
                count += 1
                print(f"finished {symbol}: {count} of {symbol_count}: time: {datetime.now()}")
            except Exception as e:
                # Best effort per symbol: report, undo this symbol, carry on.
                print(e)
                session.rollback()

    def _generate_ticker_symbols(self, model):
        """Add one ``model`` row per distinct ticker symbol and commit."""
        session = self.database.session
        symbols = [found[0] for found in session.query(distinct(Ticker.symbol)).all()]
        for symbol in symbols:
            session.add(model(symbol=symbol))
        session.commit()
def create_parser():
    """Build the command-line parser for the management script.

    Supported invocations:
        database create
        database load [file] [--chunk_size N]
        resource create_sma
        resource create_ticker_symbols
        generate [--output PATH]

    The chosen command is stored in ``args.command`` and, for ``database``
    and ``resource``, the sub-action in ``args.action``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Manager Script For Setup")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Subparser for database command.
    parser_database = subparsers.add_parser("database", help="Database management commands")
    database_subparsers = parser_database.add_subparsers(dest="action", help="Database actions")

    # Create action.
    database_subparsers.add_parser("create", help="Create the database structure")

    # Load action with chunk_size option.
    parser_load = database_subparsers.add_parser("load", help="Load Raw Ticker Data into the database")
    # nargs="?" makes the positional optional; without it argparse always
    # requires the argument and the declared default is never used.
    parser_load.add_argument(
        "file",
        type=str,
        nargs="?",
        default="./data/output/tickers.csv",
        help="Location of the file",
    )
    parser_load.add_argument("--chunk_size", type=int, default=100000, help="Chunk size for loading data")

    # Subparser for resource command.
    parser_resource = subparsers.add_parser("resource", help="Resource management commands")
    resource_subparsers = parser_resource.add_subparsers(dest="action", help="Resource actions")
    resource_subparsers.add_parser("create_sma", help="Create SMA data")
    resource_subparsers.add_parser("create_ticker_symbols", help="Get unique tickers")

    # Subparser for generate command.
    parser_generate = subparsers.add_parser("generate", help="Generate the Raw Ticker CSV")
    parser_generate.add_argument("--output", type=str, default="./data/output/tickers.csv", help="Optional output file path")
    return parser
if __name__ == "__main__":
    # Script entry point: parse the command line, build the handler with its
    # database and data generator, then dispatch to the requested command.
    parser = create_parser()
    args = parser.parse_args()
    database = Database(database_url="sqlite:///finance.db")
    data_generator = StockDataGenerator()
    handler = CommandHandler(
        database=database,
        stock_generator=data_generator
    )
    if args.command == "database":
        if args.action == "create":
            handler.create_tables()
        elif args.action == "load":
            handler.load_raw_daily_tickers(args.file, args.chunk_size)
        else:
            # "database" given without an action: show usage instead of
            # exiting silently.
            parser.print_help()
    elif args.command == "resource":
        if args.action == "create_sma":
            handler.create_resource_sma()
        elif args.action == "create_ticker_symbols":
            handler.create_resource_ticker_symbols()
        else:
            # "resource" given without an action: show usage instead of
            # exiting silently.
            parser.print_help()
    elif args.command == "generate":
        handler.generate_raw_daily_tickers(args.output)
    else:
        parser.print_help()