SZI2019SmieciarzWmi/VowpalWabbit/vowpal_utils.py

178 lines
5.8 KiB
Python
Raw Normal View History

2019-06-09 08:25:18 +02:00
import re, os
from config import MAP_NAME, GRID_WIDTH, GRID_HEIGHT, GC_X, GC_Y
from VowpalWabbit.VowpalWrapper import wrapper
# --- module constants and shared state ---------------------------------------
# Half-width (in cells) of the square window around the garbage collector that
# get_gc_area turns into Vowpal Wabbit features.
RADIUS = 2

# Last planned path (items are [x, y] lists or "pick_garbage") and the matching
# move names; both are overwritten by parse_list and read by generate_input.
COORDINATES_LIST = []
MOVES_LIST = []

# Map grid indexed as MAP_CONTENT[y][x], one character per cell. The first two
# lines of the map file are skipped (presumably a header — confirm against the
# map file format). The handle is named map_file so the builtin map() is not
# shadowed for the rest of this module.
with open(MAP_NAME, 'r') as map_file:
    MAP_CONTENT = [list(row.strip().replace(" ", "")) for row in map_file.readlines()[2:]]

# Move name -> numeric Vowpal Wabbit label.
moves_mapping = {
    "pick_garbage": 1,
    "right": 2,
    "left": 3,
    "up": 4,
    "down": 5
}
# Numeric label -> move name; derived so it can never disagree with moves_mapping.
predictions_mapping = {label: move for move, label in moves_mapping.items()}
# Map cell character -> numeric feature value emitted by get_gc_area.
environment_mapping = {
    "E": 0,
    "R": 1,
    "H": 2,
    "V": 3,
    "Y": 4,
    "B": 5,
    "G": 6
}


def parse_list(whole_result, current_x, current_y):
    """Translate a planned path into move names and record it as training data.

    :param whole_result: sequence whose items are either the string
        "pick_garbage" or the next [x, y] cell the garbage collector steps
        onto. Every step must be to an orthogonally adjacent cell of the
        previous position; otherwise a KeyError is raised.
    :param current_x: starting x coordinate of the garbage collector.
    :param current_y: starting y coordinate of the garbage collector.
    :return: list of move names ("pick_garbage", "right", "left", "up", "down"),
        one per item of ``whole_result``.

    Side effects: stores a copy of ``whole_result`` in COORDINATES_LIST and a
    copy of the moves in MOVES_LIST, then calls generate_input with the
    starting position so the examples are appended to the training file.
    """
    global COORDINATES_LIST, MOVES_LIST
    COORDINATES_LIST = whole_result.copy()

    # (dx, dy) of a single step -> move name; y grows downwards.
    step_to_move = {(0, 1): "down", (0, -1): "up", (1, 0): "right", (-1, 0): "left"}
    start_position = [current_x, current_y]
    moves = []
    for step in whole_result:
        if step == "pick_garbage":
            moves.append(step)
            continue
        delta = (step[0] - current_x, step[1] - current_y)
        current_x, current_y = step[0], step[1]
        moves.append(step_to_move[delta])

    MOVES_LIST = moves.copy()
    generate_input(start_position)
    return moves


def generate_input(current_position):
    """Append one Vowpal Wabbit training example per recorded move.

    Walks COORDINATES_LIST / MOVES_LIST (filled by parse_list) in step. For
    move ``i`` the features describe the neighbourhood of the cell the garbage
    collector occupied *before* that move. Each example is appended to
    ``./VowpalWabbit/VowpalInputData/input_<tag>.txt`` as
    ``<label> <importance> | <feature> <feature> ...``. Here ``<tag>`` is the
    ``map_<n>`` / ``map<n>_auto`` part of MAP_NAME. Every example is also
    echoed to stdout, with the move name in place of the numeric label.

    :param current_position: [x, y] of the garbage collector before the first move.
    :raises ValueError: if MAP_NAME contains no recognizable map tag.
    """
    input_file_content = []
    for i, position in enumerate(COORDINATES_LIST):
        move = MOVES_LIST[i]
        # current_position is the previous path entry. If that entry was
        # "pick_garbage", check_position falls back to the latest earlier
        # real coordinates, or to the start.
        coords = check_position(current_position, i)
        label = moves_mapping[move]
        area, importance = get_gc_area(coords, RADIUS)
        if importance is None:
            # No notable neighbours: vertical moves weigh more than horizontal ones.
            importance = 5 if move in ("up", "down") else 1
        # NOTE(review): assumes picking garbage always carries the top weight,
        # even when get_gc_area reported notable neighbours — confirm.
        if move == "pick_garbage":
            importance = 100

        features = "".join(feature + " " for feature in area)
        input_file_content.append(f"{label} {importance} | {features}")
        print(f"{move} {importance} | {features}")
        current_position = position

    match = re.search(r"map_[0-9]+|map[0-9]+_auto", MAP_NAME)
    if match is None:
        raise ValueError(f"Cannot derive an input-file tag from map name {MAP_NAME!r}")
    filename = f"./VowpalWabbit/VowpalInputData/input_{match.group(0)}.txt"
    # Append: examples from successive runs accumulate in the same training file.
    with open(filename, "a+") as input_file:
        input_file.writelines(line + "\n" for line in input_file_content)


def pass_input(position):
    """Write a single prediction example describing the surroundings of ``position``.

    The example has the form ``1 | <feature> <feature> ...``. Its features come
    from get_gc_area(position, RADIUS), and the importance value returned
    there is not needed for prediction. The file is overwritten and
    newline-terminated on every call, so it always holds exactly the one
    current example. Appending would keep growing a single unterminated line,
    while only the first prediction is ever read back.

    :param position: current [x, y] of the garbage collector.
    :return: path of the written input file.
    """
    area, _importance = get_gc_area(position, RADIUS)
    input_line = "1 | " + "".join(feature + " " for feature in area)
    print(input_line)
    filename = "./VowpalWabbit/VowpalDataCache/constant_input.txt"
    with open(filename, "w") as input_file:
        input_file.write(input_line + "\n")
    return filename


def get_gc_area(position, radius):
    """Describe the map cells around ``position`` as Vowpal Wabbit features.

    Scans the square of side ``2 * radius + 1`` centred on ``position``. The
    square is clipped to GRID_WIDTH x GRID_HEIGHT, and the centre cell itself
    is skipped. Each scanned cell becomes a feature ``F<dx><dy>:<code>``:
    dx/dy are offsets from the unclipped square's top-left corner, and <code>
    is environment_mapping[cell]. Features are produced column by column
    (x outer, y inner).

    dx and dy are concatenated without a separator. Feature names are
    therefore unambiguous only while every offset is a single digit, i.e.
    radius <= 4.

    :param position: [x, y] of the garbage collector.
    :param radius: half-width of the square window in cells.
    :return: (features, importance). importance is the highest weight among
        notable neighbours ('H' -> 90, 'B'/'Y'/'G' -> 65), or None when there
        are none.
    """
    gc_x, gc_y = position[0], position[1]
    # Top-left corner of the unclipped window; feature offsets are relative to it.
    left, top = gc_x - radius, gc_y - radius
    neighbour_weights = {"H": 90, "B": 65, "Y": 65, "G": 65}

    features = []
    importance = None
    for x in range(max(0, left), min(gc_x + radius + 1, GRID_WIDTH)):  # stay inside the grid
        for y in range(max(0, top), min(gc_y + radius + 1, GRID_HEIGHT)):
            if x == gc_x and y == gc_y:  # the collector's own cell carries no information
                continue
            cell = MAP_CONTENT[y][x]
            weight = neighbour_weights.get(cell)
            # Keep the strongest neighbour rather than whichever was scanned last.
            if weight is not None and (importance is None or weight > importance):
                importance = weight
            features.append(f"F{x - left}{y - top}:{environment_mapping[cell]}")
    return features, importance


def check_position(position, i):
    """Return usable [x, y] coordinates for entry ``i`` of the recorded path.

    :param position: an [x, y] list, or the marker string "pick_garbage".
    :param i: index in COORDINATES_LIST of the entry being processed. For the
        marker, the search for real coordinates starts at index ``i - 1`` and
        walks backwards.
    :return: ``position`` itself when it is a list. For "pick_garbage", the
        nearest earlier list entry of COORDINATES_LIST, or [GC_X, GC_Y] if
        there is none.
    :raises ValueError: for any other value, which would otherwise only fail
        later when used as coordinates.
    """
    if isinstance(position, list):
        return position
    if position != "pick_garbage":
        raise ValueError(f"An error has occurred while processing GC position: {position!r}")
    for j in range(i - 1, -1, -1):
        if isinstance(COORDINATES_LIST[j], list):
            return COORDINATES_LIST[j]
    return [GC_X, GC_Y]


def get_predicted_move(position):
    """Ask the trained Vowpal Wabbit model for the next move from ``position``.

    The neighbourhood of ``position`` is written via pass_input. ``vw`` is then
    run in test-only mode (``-t``) with ``teraz.model``, and the first
    prediction in the output file is converted to a move name.

    :param position: current [x, y] of the garbage collector.
    :return: (move, new_position). For "pick_garbage", new_position is the
        string "pick_garbage". Otherwise it is a new [x, y] list one cell
        away, or a copy of ``position`` when that step would leave the grid.
    """
    input_filename = pass_input(position)
    output_filename = "./VowpalWabbit/VowpalDataCache/constant_output.txt"
    wrapper.wrap_ex(f"vw -i ./VowpalWabbit/VowpalModels/teraz.model -t {input_filename} -p {output_filename} ")
    with open(output_filename, 'r') as predictions:
        first_line = predictions.readline()
    prediction = float(first_line.split()[0])
    move = make_move_from_prediction(prediction)
    print(position, prediction, move)

    if move == "pick_garbage":
        return move, move

    # (axis, step): axis 0 is x (bounded by GRID_WIDTH), axis 1 is y (GRID_HEIGHT).
    axis, step = {"right": (0, 1), "left": (0, -1), "down": (1, 1), "up": (1, -1)}[move]
    new_position = position.copy()
    new_position[axis] += step
    limit = GRID_WIDTH if axis == 0 else GRID_HEIGHT
    if not 0 <= new_position[axis] < limit:
        # Stay in place rather than stepping off the map.
        new_position = position.copy()
        print("VIOLATED GRID WIDTH" if axis == 0 else "VIOLATED GRID HEIGHT")
    return move, new_position


def make_move_from_prediction(prediction):
    """Map a raw Vowpal Wabbit regression output to a move name.

    Labels run from 1 to 5 (see predictions_mapping). The chosen label is
    1 + the number of cut-offs (1.5, 2.5, 3.5, 4.5) that ``prediction`` is
    strictly greater than. A value exactly on a cut-off therefore falls to
    the lower label. Anything at or below 1.5, including NaN, yields
    "pick_garbage", and anything above 4.5 yields "down".
    """
    cutoffs = (1.5, 2.5, 3.5, 4.5)
    label = 1 + sum(1 for cutoff in cutoffs if prediction > cutoff)
    return predictions_mapping[label]