Skip to content
Snippets Groups Projects
Commit c0a62c47 authored by Matthias König
Browse files

policy ctrl

parent 9f5ddd5b
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@ sys.path.append("..")
sys.path.append("../rl")
from multiprocessing import Process, Value
from threading import Thread
import argparse
import time
import uuid
from tornado import websocket
......@@ -14,7 +15,7 @@ import cv2
import numpy as np
import base64
import serial
from policy_controller import AttentionPolicyController
from policy_controller import AttentionPolicyController, load_policy_control
from camera import FisheyeCamera, AttentionCamera
from serial_com import SerialCom
from rl.ppo_continous import Agent
......@@ -32,8 +33,7 @@ steer_val = Value("f", 0)
stop_val = Value("b", False)
serial = SerialCom("config.yml")
#cam = FisheyeCamera("/dev/video0")
cam = AttentionCamera("/dev/video0", model_path, model)
# cam = FisheyeCamera("/dev/video0")
class PolicyController:
......@@ -58,8 +58,10 @@ class PolicyController:
obs = torch.Tensor(obs)
action = self.agent.get_action(obs)
action = action.numpy()
self.speed.value = alpha*self.speed.value+(1-alpha)*np.clip(action[0], -0.2, 0.7)
self.steer.value = np.clip(action[1]*1.3, -1.0, 1.0)
self.speed.value = alpha * self.speed.value + (1 - alpha) * np.clip(
action[0], -0.2, 0.7
)
self.steer.value = np.clip(action[1] * 1.3, -1.0, 1.0)
print("speed action:", action[0])
if self.stop.value:
self.speed.value = 0
......@@ -125,26 +127,27 @@ class CommandHandler(websocket.WebSocketHandler):
timeout = 0
client = None
def initialize(self, speed, steer, stop, cam, policy_control):
    """Store the shared drive state, camera and policy controller on the handler.

    The arguments arrive from the keyword dict attached to the ``/cmd_ws``
    route in ``make_app``.

    Args:
        speed: Shared speed value.
        steer: Shared steering value.
        stop: Shared stop flag.
        cam: Camera object handed to the handler.
        policy_control: Ready-made policy controller to drive with.
    """
    self.speed = speed
    self.steer = steer
    self.stop = stop
    self.cam = cam
    # The controller is injected instead of being constructed here, so the
    # handler no longer depends on module-level model settings.
    self.policy_control = policy_control
def open(self):
self.id = uuid.uuid4()
......@@ -225,14 +228,20 @@ class VideoHandler(websocket.WebSocketHandler):
self.write_message(message)
def make_app():
def make_app(cam, policy_control):
return tornado.web.Application(
[
(r"/", MainHandler),
(
r"/cmd_ws",
CommandHandler,
dict(speed=speed_val, steer=steer_val, stop=stop_val, cam=cam),
dict(
speed=speed_val,
steer=steer_val,
stop=stop_val,
cam=cam,
policy_control=policy_control,
),
),
# (r"/drive", DriveHandler, dict(speed=speed_val, steer=steer_val)),
(r"/video_ws", VideoHandler, dict(cam=cam)),
......@@ -251,7 +260,29 @@ def control_loop(speed, steer):
time.sleep(max(1.0 / frame_rate - (time.time() - start), 0))
def parse_args(argv=None):
    """Parse the server's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``model_dir`` (``None`` when not given),
        ``model_filename``, ``cam`` and ``port``. Unrecognized arguments
        are ignored rather than rejected.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-dir", help="Directory of model.")
    parser.add_argument(
        "--model-filename",
        default="best_model.pth",
        help="File name of the model to evaluate.",
    )
    parser.add_argument(
        "--cam",
        default="/dev/video0",
        help="Video input",
    )
    parser.add_argument("--port", help="Server Port", type=int, default=8000)
    # parse_known_args tolerates extra options instead of exiting on them.
    args, _ = parser.parse_known_args(argv)
    return args
if __name__ == "__main__":
    args = parse_args()
    # The camera is created first because the policy controller takes it as
    # an argument. The device path comes from --cam (default "/dev/video0").
    # NOTE(review): model_path and model are not defined in the visible part
    # of this file - confirm they still exist, or switch to
    # args.model_dir / args.model_filename.
    cam = AttentionCamera(args.cam, model_path, model)
    # argparse stores "--model-dir" as args.model_dir; args.model_path does
    # not exist. load_policy_control also expects the camera as third argument.
    policy_ctrl = load_policy_control(
        args.model_dir,
        args.model_filename,
        cam,
        speed_val,
        steer_val,
        stop_val,
    )
    loop = tornado.ioloop.IOLoop.current()
    control_process = Process(target=control_loop, args=(speed_val, steer_val))
    # Grab a camera frame every 10 ms on the IOLoop.
    capture_callback = tornado.ioloop.PeriodicCallback(lambda: cam.capture(), 10)
......@@ -259,8 +290,8 @@ if __name__ == "__main__":
lambda: CommandHandler.check_alive(), 2
)
try:
app = make_app()
app.listen(8888)
app = make_app(cam, policy_ctrl)
app.listen(args.port)
# periodic_callback = tornado.ioloop.PeriodicCallback(
# lambda: VideoHandler.send_image(cam.capture()), 100
......
......@@ -9,6 +9,7 @@ import numpy as np
import base64
from rl.ppo_continous import Agent
from rl.cardriver_env import DummyEnv
from rl.test import load_agent_from_conf
from evo2.solution import VisionTaskSolution
import torch
import abc
......@@ -16,6 +17,17 @@ import time
# Target loop rate in iterations per second; the drive loop sleeps out the
# remainder of each 1.0 / frame_rate second period.
frame_rate = 20
def load_rl_agent(path, model):
    """Build an RL agent from a model directory and load its saved state.

    Args:
        path: Directory that contains ``config.yaml`` and the saved agent file.
        model: File name of the saved agent inside ``path``.

    Returns:
        The agent created by ``load_agent_from_conf`` from the parsed config,
        after ``load_agent`` has been called on it with the saved file.
    """
    with open(os.path.join(path, "config.yaml"), "r") as f:
        settings = yaml.safe_load(f)
    agent = load_agent_from_conf(settings)
    agent.load_agent(os.path.join(path, model))
    return agent
def load_attention_agent(path, model, overplot=False):
with open(os.path.join(path, "config.yaml"), "r") as f:
settings = yaml.safe_load(f)
......@@ -60,7 +72,7 @@ class RLPolicyController(BasePolicyController):
super(RLPolicyController, self).__init__(cam, speed, steer, stop)
dummy_env = DummyEnv()
self.agent = Agent(dummy_env)
self.agent.load_agent(path)
self.agent = load_rl_agent(path)
self.state = np.zeros((n_stacked_obs, 6, 3))
def drive(self):
......@@ -130,10 +142,22 @@ class AttentionPolicyController(BasePolicyController):
# print("Speed:", speed.value, "Steer:", steer.value)
obs = self.cam.get_frame()
action = self.agent.get_output(obs)
self.speed.value = self.speed.value*alpha+(1-alpha)*np.clip(action[0], -0.6, 0.7)
self.steer.value = np.clip(action[1]*1.0, -1.0, 1.0)
self.speed.value = self.speed.value * alpha + (1 - alpha) * np.clip(
action[0], -0.6, 0.7
)
self.steer.value = np.clip(action[1] * 1.0, -1.0, 1.0)
print("speed action:", action[0])
if self.stop.value:
self.speed.value = 0
break
time.sleep(max(1.0 / frame_rate - (time.time() - start), 0))
def load_policy_control(path, model, cam, speed, steer, stop):
    """Create the policy controller that matches a model directory's config.

    Reads ``config.yaml`` from ``path``. A config with a top-level ``master``
    key selects ``AttentionPolicyController``; any other config selects
    ``RLPolicyController``.

    Args:
        path: Directory that contains ``config.yaml`` and the model file.
        model: File name of the model inside ``path``.
        cam: Camera object passed through to the controller.
        speed: Shared speed value passed through to the controller.
        steer: Shared steering value passed through to the controller.
        stop: Shared stop flag passed through to the controller.

    Returns:
        The constructed controller instance.
    """
    with open(os.path.join(path, "config.yaml"), "r") as f:
        settings = yaml.safe_load(f)
    # Membership is tested on the mapping itself; dicts have no .key() method.
    if "master" in settings:
        return AttentionPolicyController(path, model, cam, speed, steer, stop)
    return RLPolicyController(path, model, cam, speed, steer, stop)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment