From f351633a055a16b558c226b935dcc75cf5734e43 Mon Sep 17 00:00:00 2001 From: thecodebranch Date: Mon, 1 Jul 2024 11:38:33 -0600 Subject: [PATCH] Add sma indicator --- data/indicators/sma.py | 21 ++++++++++++++++++++ database/__init__.py | 12 ++++++++++- manage.py | 45 +++++++++++++++++++++++++++++++++++++++++- model/sma.py | 22 +++++++++++++++++++++ model/ticker.py | 14 ++++++++++++- 5 files changed, 111 insertions(+), 3 deletions(-) create mode 100644 data/indicators/sma.py create mode 100644 model/sma.py diff --git a/data/indicators/sma.py b/data/indicators/sma.py new file mode 100644 index 0000000..528e1d2 --- /dev/null +++ b/data/indicators/sma.py @@ -0,0 +1,21 @@ +def calculate_sma(data, period): + """ + Calculate the Simple Moving Average (SMA) for the given data. + :param data: A list of dictionaries containing 'date' and 'close' prices. + :param period: period used to calculate sma. + :return: A list of dictionaries with an additional 'SMA' key. + """ + close_prices = [entry['close'] for entry in data] + sma_values = [] + + for i in range(len(close_prices)): + if i + 1 < period: + sma_values.append(None) + else: + sma = sum(close_prices[i + 1 - period:i + 1]) / period + sma_values.append(sma) + + for i in range(len(data)): + data[i]['SMA'] = sma_values[i] + + return data \ No newline at end of file diff --git a/database/__init__.py b/database/__init__.py index badebd3..f97baf5 100644 --- a/database/__init__.py +++ b/database/__init__.py @@ -2,6 +2,8 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import scoped_session from model.base import Base +from sqlalchemy import MetaData, Table + class Database: @@ -15,4 +17,12 @@ class Database: def drop_tables(self): - Base.metadata.drop_all(self.engine) \ No newline at end of file + Base.metadata.drop_all(self.engine) + + + def drop_table(self, table): + meta = MetaData() + meta.reflect(bind=self.engine) + if table in meta.tables: + table = Table(table, meta, autoload_with=self.engine) + table.drop(self.engine, checkfirst=True) \ No newline at end of file diff --git a/manage.py b/manage.py index f97aee3..7a96f9f 100644 --- a/manage.py +++ b/manage.py @@ -1,6 +1,9 @@ import csv +import copy from datetime import datetime from database import Database +from data.indicators.sma import calculate_sma +from model.sma import Sma from model.ticker import Ticker @@ -46,5 +49,45 @@ def load_database(chunk_size=100000): pprint.pprint(e) +def create_sma(): + from sqlalchemy import distinct + + database.drop_table("sma") + database.create_tables() + + sma_periods = [5, 10, 20, 50, 100, 200] + symbols = database.session.query(distinct(Ticker.symbol)).all() + count = 0 + for symbol in symbols: + try: + daily_ticker_data = database.session.query(Ticker).filter_by(symbol=symbol[0]).all() + daily_ticker_data = [ticker.to_dict() for ticker in daily_ticker_data] + + smas = {} + for period in sma_periods: + data = copy.deepcopy(daily_ticker_data) + sma_data = calculate_sma(data, period) + smas[period] = sma_data + + for i in range(len(daily_ticker_data)): + row = Sma( + date=daily_ticker_data[i]['date'], + symbol=symbol[0], + sma_5=smas[5][i]['SMA'], + sma_10=smas[10][i]['SMA'], + sma_20=smas[20][i]['SMA'], + sma_50=smas[50][i]['SMA'], + sma_100=smas[100][i]['SMA'], + sma_200=smas[200][i]['SMA'] + ) + database.session.add(row) + database.session.commit() + count += 1 + print(f"finished {symbol[0]}: {count} of {len(symbols)}") + except Exception as e: + print(e) + database.session.rollback() + + if __name__ == "__main__": - load_database() \ No newline at end of file + create_sma() \ No newline at end of file diff --git a/model/sma.py b/model/sma.py new file mode 100644 index 0000000..1f0ec0f --- /dev/null +++ b/model/sma.py @@ -0,0 +1,22 @@ +from datetime import datetime +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import Float +from sqlalchemy import Integer +from sqlalchemy import String +from model.base import Base + + +class Sma(Base): + __tablename__ = 'sma' + + + id = Column(Integer, primary_key=True, autoincrement=True) + date = Column(DateTime) + symbol = Column(String(15), nullable=False) + sma_5 = Column(Float) + sma_10 = Column(Float) + sma_20 = Column(Float) + sma_50 = Column(Float) + sma_100 = Column(Float) + sma_200 = Column(Float) \ No newline at end of file diff --git a/model/ticker.py b/model/ticker.py index 7cb9c6d..f9eca14 100644 --- a/model/ticker.py +++ b/model/ticker.py @@ -18,4 +18,16 @@ class Ticker(Base): high = Column(Float) low = Column(Float) close = Column(Float) - volume = Column(Integer) \ No newline at end of file + volume = Column(Integer) + + + def to_dict(self): + return { + "date": self.date, + "symbol": self.symbol, + "open": self.open, + "high": self.high, + "low": self.low, + "close": self.close, + "volume": self.volume, + } \ No newline at end of file