216 lines
8.2 KiB
Python
216 lines
8.2 KiB
Python
import argparse
|
|
import csv
|
|
import copy
|
|
from datetime import datetime
|
|
from sqlalchemy import distinct
|
|
from sqlalchemy import MetaData
|
|
from sqlalchemy import Table
|
|
from sqlalchemy.orm import declarative_base
|
|
from sqlalchemy.orm import registry
|
|
from data.data_generator import StockDataGenerator
|
|
from data.indicators.sma import calculate_sma
|
|
from database import Database
|
|
from model.sma import Sma
|
|
from model.ticker import Ticker
|
|
from model.tickerSymbol import TickerSymbol
|
|
|
|
|
|
class CommandHandler:
    """Implements the manager-script commands on top of a ``Database`` and a stock generator.

    The table-rebuilding commands share one pattern:
    1. create an empty rebuild table with the live table's structure;
    2. fill it;
    3. swap it in, keeping the previous live table as ``<name>_backup``.
    """

    def __init__(self, database=None, stock_generator=None):
        """Store the collaborators used by the commands.

        Args:
            database: Project ``Database`` wrapper. This class uses its ``engine``,
                ``session``, ``create_tables``, ``copy_table_structure``,
                ``rename_table`` and ``drop_table`` members.
            stock_generator: Object exposing ``generate_raw_data(output_file)``.
        """
        self.database = database
        self.stock_generator = stock_generator
        # Not used by any method in this class; presumably kept for external callers.
        self.mapper_registry = registry()

    def create_tables(self):
        """Create the database schema via ``Database.create_tables``."""
        self.database.create_tables()

    def generate_raw_daily_tickers(self, output_file):
        """Write raw daily ticker data to *output_file* using the stock generator."""
        self.stock_generator.generate_raw_data(output_file)

    def load_raw_daily_tickers(self, path, chunk_size):
        """Rebuild the ``ticker`` table from the CSV file at *path*.

        Rows are loaded into ``ticker_rebuild`` in batches of *chunk_size*. That
        table replaces ``ticker`` only if loading raised no error. The previous
        ``ticker`` table is kept as ``ticker_backup``.
        """
        self._drop_table("ticker_rebuild")
        copy_model = self._copy_existing_table_structure("ticker", "ticker_rebuild")
        # Raises on failure, so a partially loaded table is never swapped in below.
        self._load_raw_ticker_data(copy_model, path, chunk_size)
        self._swap_in_rebuilt_table("ticker", "ticker_rebuild", "ticker_backup")

    def create_resource_sma(self):
        """Rebuild the ``sma`` table from the current ``ticker`` data (old table kept as ``sma_backup``)."""
        self._drop_table("sma_rebuild")
        copy_model = self._copy_existing_table_structure("sma", "sma_rebuild")
        self._generate_sma(copy_model)
        self._swap_in_rebuilt_table("sma", "sma_rebuild", "sma_backup")

    def create_resource_ticker_symbols(self):
        """Rebuild ``ticker_symbol`` from the distinct symbols in ``ticker`` (old table kept as ``ticker_symbol_backup``)."""
        self._drop_table("ticker_symbols_rebuild")
        copy_model = self._copy_existing_table_structure("ticker_symbol", "ticker_symbols_rebuild")
        self._generate_ticker_symbols(copy_model)
        self._swap_in_rebuilt_table("ticker_symbol", "ticker_symbols_rebuild", "ticker_symbol_backup")

    def _swap_in_rebuilt_table(self, live_name, rebuild_name, backup_name):
        """Replace *live_name* with *rebuild_name*, keeping the old live table as *backup_name*."""
        # A backup left over from an earlier run would make the first rename collide
        # with an existing table, so clear it first.
        # NOTE(review): assumes Database.drop_table tolerates a missing table. The
        # unconditional *_rebuild drops on a first run already rely on that; confirm.
        self._drop_table(backup_name)
        self._rename_table(live_name, backup_name)
        self._rename_table(rebuild_name, live_name)

    def _copy_existing_table_structure(self, table_name, table_copy_name):
        """Create *table_copy_name* with *table_name*'s structure and return an ORM model for it.

        Returns:
            A newly declared mapped class whose ``__table__`` is the copy table,
            reflected from the database.
        """
        self.database.copy_table_structure(table_name, table_copy_name)

        metadata = MetaData()
        metadata.reflect(bind=self.database.engine)

        # Use this handler's engine, not a module-level global, so the class also
        # works when imported instead of being run as a script.
        copy_table = Table(table_copy_name, metadata, autoload_with=self.database.engine)
        temp_base = declarative_base(metadata=metadata)

        # Create a temporary model on the fly to load new data into.
        class TempModel(temp_base):
            __table__ = copy_table

        return TempModel

    def _rename_table(self, current_name, new_name):
        """Rename table *current_name* to *new_name* via ``Database.rename_table``."""
        self.database.rename_table(current_name, new_name)

    def _drop_table(self, tablename):
        """Drop table *tablename* via ``Database.drop_table``."""
        self.database.drop_table(tablename)

    def _load_raw_ticker_data(self, model, path, chunk_size):
        """Insert every row of the ticker CSV at *path* as *model* instances.

        Objects are saved with ``bulk_save_objects`` and committed every
        *chunk_size* rows, plus once for any remainder.

        The CSV must have these columns:
        - ``date`` in ``YYYY-MM-DD`` format, converted to a ``date``;
        - ``symbol``, ``open``, ``high``, ``low``, ``close``, ``volume``, passed
          through as the strings read from the file.

        Raises:
            Any parsing or database error. The open transaction is rolled back
            first, and the error is re-raised so the caller does not swap a
            partially loaded table in place of the live one.
        """
        rows = []
        total = 0
        # The csv module expects files it reads to be opened with newline="".
        with open(path, "r", newline="") as ticker_csv:
            try:
                for row in csv.DictReader(ticker_csv):
                    rows.append(model(
                        date=datetime.strptime(row.get("date"), "%Y-%m-%d").date(),
                        symbol=row.get("symbol"),
                        open=row.get("open"),
                        high=row.get("high"),
                        low=row.get("low"),
                        close=row.get("close"),
                        volume=row.get("volume"),
                    ))

                    if len(rows) >= chunk_size:
                        self.database.session.bulk_save_objects(rows)
                        self.database.session.commit()
                        total += len(rows)
                        print(f"{total} rows inserted")
                        rows = []

                # Remaining rows.
                if rows:
                    self.database.session.bulk_save_objects(rows)
                    self.database.session.commit()
                    total += len(rows)
                    print(f"{total} inserted")
            except Exception:
                self.database.session.rollback()
                raise

    def _generate_sma(self, model):
        """Compute 5/10/20/50/100/200-period SMAs per ticker symbol and store them as *model* rows.

        Symbols are handled independently (best effort). If one symbol fails,
        the error is printed, that symbol's uncommitted rows are rolled back,
        and processing continues with the next symbol.
        """
        sma_periods = [5, 10, 20, 50, 100, 200]
        symbols = [value for (value,) in self.database.session.query(distinct(Ticker.symbol)).all()]
        total_symbols = len(symbols)
        count = 0

        for symbol in symbols:
            try:
                # A moving average is only meaningful over chronologically ordered rows,
                # so ask for that order explicitly. Without ORDER BY, SQL row order is
                # unspecified.
                # NOTE(review): assumes calculate_sma expects ascending dates; confirm.
                daily_ticker_data = [
                    ticker.to_dict()
                    for ticker in self.database.session.query(Ticker)
                    .filter_by(symbol=symbol)
                    .order_by(Ticker.date)
                    .all()
                ]

                # Each period gets its own deep copy, presumably because calculate_sma
                # mutates the rows it is given.
                smas = {
                    period: calculate_sma(copy.deepcopy(daily_ticker_data), period)
                    for period in sma_periods
                }

                for i, day in enumerate(daily_ticker_data):
                    self.database.session.add(model(
                        date=datetime.strptime(day["date"], "%Y-%m-%d"),
                        symbol=symbol,
                        close=day["close"],
                        sma_5=smas[5][i]["SMA"],
                        sma_10=smas[10][i]["SMA"],
                        sma_20=smas[20][i]["SMA"],
                        sma_50=smas[50][i]["SMA"],
                        sma_100=smas[100][i]["SMA"],
                        sma_200=smas[200][i]["SMA"],
                    ))

                self.database.session.commit()
                count += 1
                print(f"finished {symbol}: {count} of {total_symbols}: time: {datetime.now()}")
            except Exception as e:
                print(e)
                self.database.session.rollback()

    def _generate_ticker_symbols(self, model):
        """Insert one *model* row per distinct symbol found in ``Ticker``, then commit."""
        symbols = self.database.session.query(distinct(Ticker.symbol)).all()
        self.database.session.add_all(model(symbol=value) for (value,) in symbols)
        self.database.session.commit()
|
|
|
|
|
|
def create_parser():
    """Build the command-line parser for the manager script.

    Commands:
        database create                        create the database structure
        database load [FILE] [--chunk_size N]  load the raw ticker CSV
        resource create_sma                    build the SMA table
        resource create_ticker_symbols         build the ticker-symbol table
        generate [--output PATH]               generate the raw ticker CSV

    Returns:
        argparse.ArgumentParser: the configured parser. ``command`` is ``None``
        when no command is given; ``database`` and ``resource`` require an action.
    """
    parser = argparse.ArgumentParser(description="Manager Script For Setup")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Subparser for database command.
    parser_database = subparsers.add_parser("database", help="Database management commands")
    database_subparsers = parser_database.add_subparsers(dest="action", help="Database actions")
    # Without this, `database` alone parses fine and the script silently does nothing.
    # Set as an attribute instead of the 3.7+ `required=` keyword so older Pythons work.
    database_subparsers.required = True

    # Create action.
    database_subparsers.add_parser("create", help="Create the database structure")

    # Load action with chunk_size option.
    parser_load = database_subparsers.add_parser("load", help="Load Raw Ticker Data into the database")
    # nargs="?" makes the positional optional. Without it, argparse ignores
    # `default` and always demands a file argument.
    parser_load.add_argument(
        "file",
        type=str,
        nargs="?",
        default="./data/output/tickers.csv",
        help="Location of the file",
    )
    parser_load.add_argument("--chunk_size", type=int, default=100000, help="Chunk size for loading data")

    # Subparser for resource command.
    parser_resource = subparsers.add_parser("resource", help="Resource management commands")
    resource_subparsers = parser_resource.add_subparsers(dest="action", help="Resource actions")
    resource_subparsers.required = True
    resource_subparsers.add_parser("create_sma", help="Create SMA data")
    resource_subparsers.add_parser("create_ticker_symbols", help="Get unique tickers")

    # Subparser for generate command.
    parser_generate = subparsers.add_parser("generate", help="Generate the Raw Ticker CSV")
    parser_generate.add_argument(
        "--output",
        type=str,
        default="./data/output/tickers.csv",
        help="Optional output file path",
    )

    return parser
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: parse the CLI, wire up the collaborators, then run the selected command.
    parser = create_parser()
    args = parser.parse_args()

    # Kept as module-level names (not wrapped in a function) so `database` stays a global.
    database = Database(database_url="sqlite:///finance.db")
    data_generator = StockDataGenerator()
    handler = CommandHandler(database=database, stock_generator=data_generator)

    # (command, action) pairs that carry an action, mapped to the call each one runs.
    # A pair missing from this table is a silent no-op.
    dispatch = {
        ("database", "create"): lambda: handler.create_tables(),
        ("database", "load"): lambda: handler.load_raw_daily_tickers(args.file, args.chunk_size),
        ("resource", "create_sma"): lambda: handler.create_resource_sma(),
        ("resource", "create_ticker_symbols"): lambda: handler.create_resource_ticker_symbols(),
    }

    if args.command in ("database", "resource"):
        task = dispatch.get((args.command, args.action))
        if task is not None:
            task()
    elif args.command == "generate":
        handler.generate_raw_daily_tickers(args.output)
    else:
        parser.print_help()