processing_for_model2.py
2.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 13 16:57:11 2017
@author: red-sky
"""
import sys
import numpy as np
import pickle
import pandas as pd
def main(VectorsPath, EventPath, StockPricePath, days):
    """Build sliding-window samples of daily mean vectors with up/down labels.

    Steps:
      1. Load the pickled vector table and the tab-separated event file, and
         pair the i-th event line with the i-th vector.
      2. Group vectors by the first 10 characters of each event line's first
         field (presumably a ``YYYY-MM-DD`` date prefix -- confirm against the
         event file format) and average every group element-wise.
      3. For every run of ``days`` consecutive keys (in sorted order) build one
         sample; the sample is labelled with the ``ReturnOpen`` value found in
         the price CSV for the *last* key of the run.
      4. Drop samples whose key has no row in the CSV or whose return is NaN.

    Args:
        VectorsPath: Path to a pickle file holding a dict whose values are
            pairs ``(a, b)``; ``b[0]`` is used as the vector.
        EventPath: Path to a text file with one tab-separated event per line.
        StockPricePath: Path to a CSV file with ``Date`` and ``ReturnOpen``
            columns.
        days: Number of consecutive keys per sample (window length).

    Returns:
        Tuple ``(dataX, dataY)`` of numpy arrays. ``dataX`` stacks the kept
        windows (one row of ``days`` mean vectors per sample); ``dataY`` holds
        one ``[1, 0]`` (return > 0) or ``[0, 1]`` (return <= 0) row per sample.

    Raises:
        ValueError: If ``days`` is smaller than 1.
    """
    if days < 1:
        raise ValueError("days must be a positive integer, got %r" % (days,))

    # NOTE(security): pickle.load can execute arbitrary code; only feed this
    # function vector files that come from a trusted source.
    with open(VectorsPath, "rb") as H:
        Vec = pickle.load(H)
    Vectors = np.array([list(b[0]) for a, b in Vec.values()])

    # Only the first field of every event line is needed.
    with open(EventPath, "r") as H:
        event_stamps = [line.split("\t")[0] for line in H.read().splitlines()]

    # Group vectors by day key. zip() stops at the shorter input, so surplus
    # events or vectors are ignored.
    by_day = {}
    for stamp, vec in zip(event_stamps, Vectors):
        by_day.setdefault(stamp[:10], []).append(vec)
    daily_mean = {day: np.mean(vecs, 0) for day, vecs in by_day.items()}

    Dates = sorted(daily_mean)

    # One window per run of `days` consecutive keys. Starting at `days` (not a
    # fixed 5) keeps every index non-negative for any window length.
    DataX = []
    DateX = []
    for end in range(days, len(Dates)):
        window = Dates[end - days:end]
        DataX.append([daily_mean[day] for day in window])
        DateX.append(window[-1])

    # Build the date -> return table once; the first row wins on duplicates.
    Df = pd.read_csv(StockPricePath)
    first_return = {}
    for day, value in zip(Df["Date"], Df["ReturnOpen"]):
        first_return.setdefault(day, value)

    LabelY = []
    DataX_yesData = []
    for window, day in zip(DataX, DateX):
        if day not in first_return:
            continue  # no price row for this day
        retu = float(first_return[day])
        if np.isnan(retu):
            continue  # missing return: skip so samples and labels stay aligned
        LabelY.append([1, 0] if retu > 0 else [0, 1])
        DataX_yesData.append(list(window))

    dataX = np.array(DataX_yesData)
    dataY = np.array(LabelY)
    print("DataX:", dataX.shape)
    if dataY.size:
        print("DataY:", dataY.shape, np.sum(dataY, 0) / np.sum(dataY))
    else:
        # Nothing matched; avoid a 0/0 when reporting the class balance.
        print("DataY:", dataY.shape)
    return (dataX, dataY)
if __name__ == "__main__":
    # Command line (positional):
    #   1: pickled vectors file   2: tab-separated event file
    #   3: stock price CSV        4: output directory
    #   5: window length in days (also embedded verbatim in output file names)
    if len(sys.argv) < 6:
        sys.exit(
            "usage: " + sys.argv[0]
            + " VECTORS_PKL EVENT_FILE STOCK_PRICE_CSV OUTPUT_DIR DAYS"
        )
    VectorsPath = sys.argv[1]
    EventPath = sys.argv[2]
    StockPricePath = sys.argv[3]
    DataPath = sys.argv[4]
    days = int(sys.argv[5])

    DataX, LabelY = main(VectorsPath, EventPath, StockPricePath, days)

    # Samples and labels are written side by side into the output directory.
    np.save(arr=DataX, file=DataPath+"/DailyVector" + sys.argv[5] + ".npy")
    np.save(arr=LabelY, file=DataPath+"/DailyReturn" + sys.argv[5] + ".npy")