79 lines
2.5 KiB
Python
79 lines
2.5 KiB
Python
from .base import BaseService
|
|
from model.sma import Sma
|
|
from api.errors import ResourceNotFound
|
|
from api.errors import InvalidSma
|
|
|
|
|
|
class SmaService(BaseService):
    """Service returning simple-moving-average (SMA) rows for a ticker."""

    def get_sma(self, ticker, params):
        """Return SMA data for ``ticker`` filtered by request ``params``.

        Args:
            ticker: Symbol compared against ``Sma.symbol``.
            params: Mapping of request parameters. ``"period"`` is a
                comma-separated list of SMA periods (e.g. ``"5,20"``).
                ``"from_date"`` / ``"to_date"`` bound ``Sma.date``
                inclusively and default to ``"2000-01-01"`` /
                ``"2001-01-01"``.

        Returns:
            ``{"ticker": ticker, "data": [...]}`` where each ``data`` entry
            holds ``"date"``, ``"close"`` and one key per requested SMA
            column.

        Raises:
            InvalidSma: If ``"period"`` is missing or names an unsupported
                period.
            ResourceNotFound: If no rows match the ticker and date range.
        """
        # Validate SMA periods first so bad requests fail before any query.
        sma_columns = self._validate_and_get_columns(params.get("period"))

        dates = {
            "from_date": params.get("from_date", "2000-01-01"),
            "to_date": params.get("to_date", "2001-01-01"),
        }

        data = self._query_sma_data(ticker, sma_columns, dates)
        if not data:
            raise ResourceNotFound

        return {
            "ticker": ticker,
            "data": self._format_results(data, sma_columns),
        }

    def _validate_and_get_columns(self, sma_periods):
        """Map a comma-separated period string to ``Sma`` column attributes.

        Whitespace around each period is ignored. Repeated periods are
        collapsed, keeping first-occurrence order.

        Args:
            sma_periods: String such as ``"5,20,200"``; may be ``None`` when
                the request omitted the parameter.

        Returns:
            List of ``Sma`` column attributes in request order.

        Raises:
            InvalidSma: If ``sma_periods`` is not a string or any entry is
                not one of the supported periods.
        """
        valid_sma_columns = {
            "5": Sma.sma_5,
            "10": Sma.sma_10,
            "20": Sma.sma_20,
            "50": Sma.sma_50,
            "100": Sma.sma_100,
            "200": Sma.sma_200,
        }

        # A missing "period" parameter arrives as None; report it as an
        # invalid SMA request rather than crashing on ``.split``.
        if not isinstance(sma_periods, str):
            raise InvalidSma

        sma_columns = []
        seen_periods = set()
        for raw_period in sma_periods.split(","):
            period = raw_period.strip()
            if period not in valid_sma_columns:
                raise InvalidSma
            # The output dict has one key per column anyway, so selecting
            # the same column twice only adds redundant SQL.
            if period in seen_periods:
                continue
            seen_periods.add(period)
            sma_columns.append(valid_sma_columns[period])
        return sma_columns

    def _query_sma_data(self, ticker, sma_columns, dates):
        """Fetch rows of ``(symbol, date, close, *sma_columns)`` for ``ticker``.

        Args:
            ticker: Symbol compared against ``Sma.symbol``.
            sma_columns: ``Sma`` column attributes to append to the select.
            dates: Mapping with ``"from_date"`` and ``"to_date"``; both bounds
                are inclusive.

        Returns:
            The list produced by ``query.all()``.
        """
        # Leading columns are symbol, date, close in that order;
        # _format_results relies on this positional layout.
        query = self.database.session.query(Sma.symbol, Sma.date, Sma.close)
        query = query.add_columns(*sma_columns)
        query = query.filter(
            Sma.symbol == ticker,
            Sma.date >= dates.get("from_date"),
            Sma.date <= dates.get("to_date"),
        )
        return query.all()

    def _format_results(self, data, sma_columns):
        """Convert query rows into a list of JSON-friendly dicts.

        Args:
            data: Rows shaped ``(symbol, date, close, *sma_values)``.
            sma_columns: The SMA columns that were selected, in select order.

        Returns:
            List of dicts with ``"date"`` (``YYYY-MM-DD``), ``"close"`` and
            one entry per SMA column keyed by ``column.key``.
        """
        result = []
        for row in data:
            row_dict = {
                "date": row.date.strftime("%Y-%m-%d"),
                "close": row.close,
            }
            # SMA values start at index 3: symbol, date and close come first.
            for index, column in enumerate(sma_columns, start=3):
                row_dict[column.key] = row[index]
            result.append(row_dict)
        return result