"""Manager script for setting up and rebuilding the finance database.

Exposes a small command line interface with three command groups:

* ``database create`` / ``database load`` - create the schema and load raw
  daily ticker rows from a CSV file.
* ``resource create_sma`` / ``resource create_ticker_symbols`` - rebuild
  derived tables from the ``ticker`` table.
* ``generate`` - write the raw ticker CSV via ``StockDataGenerator``.
"""
import argparse
import copy
import csv
from datetime import datetime

from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import distinct
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import registry

from data.data_generator import StockDataGenerator
from data.indicators.sma import calculate_sma
from database import Database
from model.sma import Sma
from model.ticker import Ticker
from model.tickerSymbol import TickerSymbol

# Date layout used both in the raw CSV and in ``Ticker.to_dict()["date"]``.
DATE_FORMAT = "%Y-%m-%d"

# Window sizes for which a ``sma_<period>`` column is populated.
SMA_PERIODS = (5, 10, 20, 50, 100, 200)

DEFAULT_TICKER_CSV = "./data/output/tickers.csv"


class CommandHandler:
    """Implements the actions behind each command line sub-command.

    Every "rebuild" style action follows the same sequence: drop a stale
    ``*_rebuild`` table, copy the live table's structure into a fresh rebuild
    table, populate it, rename the live table to ``*_backup`` and finally
    rename the rebuild table into the live table's place.
    """

    def __init__(self, database=None, stock_generator=None):
        """Store the collaborators used by the commands.

        Args:
            database: Object providing ``engine``, ``session``,
                ``create_tables``, ``copy_table_structure``, ``rename_table``
                and ``drop_table``.
            stock_generator: Object providing ``generate_raw_data``.
        """
        self.database = database
        self.stock_generator = stock_generator
        self.mapper_registry = registry()

    def create_tables(self):
        """Create the database structure by delegating to the database."""
        self.database.create_tables()

    def generate_raw_daily_tickers(self, output_file):
        """Write generated raw ticker data to ``output_file``."""
        self.stock_generator.generate_raw_data(output_file)

    def load_raw_daily_tickers(self, path, chunk_size):
        """Rebuild the ``ticker`` table from the CSV file at ``path``.

        Args:
            path: Location of the raw ticker CSV file.
            chunk_size: Number of rows to accumulate before each bulk insert.

        Raises:
            Exception: Whatever the load raised; in that case the live
                ``ticker`` table is left untouched.
        """
        self._rebuild_table(
            "ticker",
            "ticker_rebuild",
            "ticker_backup",
            lambda model: self._load_raw_ticker_data(model, path, chunk_size),
        )

    def create_resource_sma(self):
        """Rebuild the ``sma`` table from the rows in ``ticker``."""
        self._rebuild_table(
            "sma", "sma_rebuild", "sma_backup", self._generate_sma
        )

    def create_resource_ticker_symbols(self):
        """Rebuild the ``ticker_symbol`` table from the symbols in ``ticker``."""
        self._rebuild_table(
            "ticker_symbol",
            "ticker_symbols_rebuild",
            "ticker_symbol_backup",
            self._generate_ticker_symbols,
        )

    def _rebuild_table(self, live_name, rebuild_name, backup_name, populate):
        """Populate a fresh copy of ``live_name`` and swap it into place.

        ``populate`` is called with the temporary model mapped to the rebuild
        table. If it raises, the renames are skipped so the live table keeps
        its current contents.
        """
        self._drop_table(rebuild_name)
        copy_model = self._copy_existing_table_structure(
            live_name, rebuild_name
        )
        populate(copy_model)
        self._rename_table(live_name, backup_name)
        self._rename_table(rebuild_name, live_name)

    def _copy_existing_table_structure(self, table_name, table_copy_name):
        """Copy a table's structure and return a model mapped to the copy.

        Args:
            table_name: Name of the existing table to copy.
            table_copy_name: Name of the table to create.

        Returns:
            A declarative model class whose ``__table__`` is the reflected
            copy, usable for inserting rows into ``table_copy_name``.
        """
        self.database.copy_table_structure(table_name, table_copy_name)
        metadata = MetaData()
        metadata.reflect(bind=self.database.engine)
        # Use the handler's own database rather than a module-level global,
        # which only exists when this file is executed as a script.
        copy_table = Table(
            table_copy_name, metadata, autoload_with=self.database.engine
        )
        TempBase = declarative_base(metadata=metadata)

        # Create a temporary model on the fly to load new data into.
        class TempModel(TempBase):
            __table__ = copy_table

        return TempModel

    def _rename_table(self, current_name, new_name):
        """Rename ``current_name`` to ``new_name`` via the database."""
        self.database.rename_table(current_name, new_name)

    def _drop_table(self, tablename):
        """Drop ``tablename`` via the database."""
        self.database.drop_table(tablename)

    def _load_raw_ticker_data(self, model, path, chunk_size):
        """Insert the rows of the CSV at ``path`` as ``model`` instances.

        Rows are bulk-saved and committed every ``chunk_size`` rows, with a
        final flush for the remainder. Progress is printed after each flush.

        Args:
            model: Mapped class accepting ``date``, ``symbol``, ``open``,
                ``high``, ``low``, ``close`` and ``volume`` keyword arguments.
            path: Location of the CSV file; it must have a header row with
                those column names.
            chunk_size: Number of rows per bulk insert.

        Raises:
            Exception: Any error raised while parsing or inserting. The
                session is rolled back and the error is re-raised so that the
                caller does not swap a partially loaded table into place.
        """
        session = self.database.session
        # newline="" is what the csv module expects for files it parses.
        with open(path, "r", newline="") as ticker_csv:
            try:
                rows = []
                total = 0
                for row in csv.DictReader(ticker_csv):
                    rows.append(
                        model(
                            date=datetime.strptime(
                                row.get("date"), DATE_FORMAT
                            ).date(),
                            symbol=row.get("symbol"),
                            open=row.get("open"),
                            high=row.get("high"),
                            low=row.get("low"),
                            close=row.get("close"),
                            volume=row.get("volume"),
                        )
                    )
                    if len(rows) >= chunk_size:
                        session.bulk_save_objects(rows)
                        session.commit()
                        # Count what was actually written, not chunk_size.
                        total += len(rows)
                        print(f"{total} rows inserted")
                        rows = []
                # Remaining rows.
                if rows:
                    session.bulk_save_objects(rows)
                    session.commit()
                    total += len(rows)
                    print(f"{total} inserted")
            except Exception as e:
                print(e)
                session.rollback()
                raise

    def _distinct_symbols(self):
        """Return the distinct ``Ticker.symbol`` values as a list."""
        result = self.database.session.query(distinct(Ticker.symbol)).all()
        return [row[0] for row in result]

    def _generate_sma(self, model):
        """Fill the table behind ``model`` with SMA rows for every symbol.

        For each distinct symbol all its ticker rows are read, the simple
        moving average is computed for every period in ``SMA_PERIODS`` and one
        row per ticker row is added. Each symbol is committed separately; a
        failure for one symbol is printed and rolled back, then processing
        continues with the next symbol.
        """
        session = self.database.session
        symbols = self._distinct_symbols()
        count = 0
        for symbol in symbols:
            try:
                tickers = session.query(Ticker).filter_by(symbol=symbol).all()
                daily_ticker_data = [ticker.to_dict() for ticker in tickers]
                smas = {}
                for period in SMA_PERIODS:
                    # Each calculation gets its own copy of the input;
                    # presumably calculate_sma mutates it - verify before
                    # removing the copy.
                    data = copy.deepcopy(daily_ticker_data)
                    smas[period] = calculate_sma(data, period)
                for i, daily in enumerate(daily_ticker_data):
                    sma_columns = {
                        f"sma_{period}": smas[period][i]["SMA"]
                        for period in SMA_PERIODS
                    }
                    session.add(
                        model(
                            date=datetime.strptime(daily["date"], DATE_FORMAT),
                            symbol=symbol,
                            close=daily["close"],
                            **sma_columns,
                        )
                    )
                session.commit()
                count += 1
                print(
                    f"finished {symbol}: {count} of {len(symbols)}: "
                    f"time: {datetime.now()}"
                )
            except Exception as e:
                print(e)
                session.rollback()

    def _generate_ticker_symbols(self, model):
        """Add one ``model`` row per distinct ticker symbol and commit."""
        session = self.database.session
        for symbol in self._distinct_symbols():
            session.add(model(symbol=symbol))
        session.commit()


def create_parser():
    """Build the argument parser for the manager script.

    Returns:
        An ``argparse.ArgumentParser`` whose namespace carries ``command``
        and, for the ``database`` and ``resource`` commands, ``action``.
    """
    parser = argparse.ArgumentParser(description="Manager Script For Setup")
    subparsers = parser.add_subparsers(
        dest="command", help="Available commands"
    )

    # Subparser for database command.
    parser_database = subparsers.add_parser(
        "database", help="Database management commands"
    )
    database_subparsers = parser_database.add_subparsers(
        dest="action", help="Database actions"
    )

    # Create action.
    database_subparsers.add_parser(
        "create", help="Create the database structure"
    )

    # Load action with chunk_size option.
    parser_load = database_subparsers.add_parser(
        "load", help="Load Raw Ticker Data into the database"
    )
    # nargs="?" makes the positional optional so the declared default can
    # actually take effect; passing a path explicitly still works.
    parser_load.add_argument(
        "file",
        type=str,
        nargs="?",
        default=DEFAULT_TICKER_CSV,
        help="Location of the file",
    )
    parser_load.add_argument(
        "--chunk_size",
        type=int,
        default=100000,
        help="Chunk size for loading data",
    )

    # Subparser for resource command.
    parser_resource = subparsers.add_parser(
        "resource", help="Resource management commands"
    )
    resource_subparsers = parser_resource.add_subparsers(
        dest="action", help="Resource actions"
    )
    resource_subparsers.add_parser("create_sma", help="Create SMA data")
    resource_subparsers.add_parser(
        "create_ticker_symbols", help="Get unique tickers"
    )

    # Subparser for generate command.
    parser_generate = subparsers.add_parser(
        "generate", help="Generate the Raw Ticker CSV"
    )
    parser_generate.add_argument(
        "--output",
        type=str,
        default=DEFAULT_TICKER_CSV,
        help="Optional output file path",
    )
    return parser


def main(argv=None):
    """Parse ``argv`` and run the requested command.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Prints the top-level help when no command, or a command without a
    recognised action, is given.
    """
    parser = create_parser()
    args = parser.parse_args(argv)

    database = Database(database_url="sqlite:///finance.db")
    data_generator = StockDataGenerator()
    handler = CommandHandler(
        database=database,
        stock_generator=data_generator,
    )

    # ``action`` is absent from the namespace for commands without actions.
    action = getattr(args, "action", None)
    dispatch = {
        ("database", "create"): handler.create_tables,
        ("database", "load"): lambda: handler.load_raw_daily_tickers(
            args.file, args.chunk_size
        ),
        ("resource", "create_sma"): handler.create_resource_sma,
        ("resource", "create_ticker_symbols"):
            handler.create_resource_ticker_symbols,
        ("generate", None): lambda: handler.generate_raw_daily_tickers(
            args.output
        ),
    }
    run = dispatch.get((args.command, action))
    if run is None:
        parser.print_help()
    else:
        run()


if __name__ == "__main__":
    main()