Improved Python agent performance (#5555)
* Improved agent performance * CHANGELOG
This commit is contained in:
parent
eef45bc208
commit
bf2815ec58
|
@ -14,6 +14,7 @@
|
|||
* Added check to avoid adding procedural trigger boxes inside intersections.
|
||||
* Python agents now accept a carla.Map and GlobalRoutePlanner instances as inputs, avoiding the need to recompute them.
|
||||
* Python agents now have a function to lane change.
|
||||
* Improved Python agents performance for large maps.
|
||||
* Fix a bug at `Map.get_topology()`, causing lanes with no successors to not be part of it.
|
||||
* Added new ConstantVelocityAgent
|
||||
* Added new parameter to the TrafficManager, `set_desired_speed`, to set a vehicle's speed.
|
||||
|
|
|
@ -89,6 +89,10 @@ class BasicAgent(object):
|
|||
else:
|
||||
self._global_planner = GlobalRoutePlanner(self._map, self._sampling_resolution)
|
||||
|
||||
# Get the static elements of the scene
|
||||
self._lights_list = self._world.get_actors().filter("*traffic_light*")
|
||||
self._lights_map = {} # Dictionary mapping a traffic light to a wp corrspoing to its trigger volume location
|
||||
|
||||
def add_emergency_stop(self, control):
|
||||
"""
|
||||
Overwrites the throttle a brake values of a control to perform an emergency stop.
|
||||
|
@ -178,9 +182,7 @@ class BasicAgent(object):
|
|||
hazard_detected = False
|
||||
|
||||
# Retrieve all relevant actors
|
||||
actor_list = self._world.get_actors()
|
||||
vehicle_list = actor_list.filter("*vehicle*")
|
||||
lights_list = actor_list.filter("*traffic_light*")
|
||||
vehicle_list = self._world.get_actors().filter("*vehicle*")
|
||||
|
||||
vehicle_speed = get_speed(self._vehicle) / 3.6
|
||||
|
||||
|
@ -192,7 +194,7 @@ class BasicAgent(object):
|
|||
|
||||
# Check if the vehicle is affected by a red traffic light
|
||||
max_tlight_distance = self._base_tlight_threshold + vehicle_speed
|
||||
affected_by_tlight, _ = self._affected_by_traffic_light(lights_list, max_tlight_distance)
|
||||
affected_by_tlight, _ = self._affected_by_traffic_light(self._lights_list, max_tlight_distance)
|
||||
if affected_by_tlight:
|
||||
hazard_detected = True
|
||||
|
||||
|
@ -268,14 +270,21 @@ class BasicAgent(object):
|
|||
ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location)
|
||||
|
||||
for traffic_light in lights_list:
|
||||
object_location = get_trafficlight_trigger_location(traffic_light)
|
||||
object_waypoint = self._map.get_waypoint(object_location)
|
||||
if traffic_light.id in self._lights_map:
|
||||
trigger_wp = self._lights_map[traffic_light.id]
|
||||
else:
|
||||
trigger_location = get_trafficlight_trigger_location(traffic_light)
|
||||
trigger_wp = self._map.get_waypoint(trigger_location)
|
||||
self._lights_map[traffic_light.id] = trigger_wp
|
||||
|
||||
if object_waypoint.road_id != ego_vehicle_waypoint.road_id:
|
||||
if trigger_wp.transform.location.distance(ego_vehicle_location) > max_distance:
|
||||
continue
|
||||
|
||||
if trigger_wp.road_id != ego_vehicle_waypoint.road_id:
|
||||
continue
|
||||
|
||||
ve_dir = ego_vehicle_waypoint.transform.get_forward_vector()
|
||||
wp_dir = object_waypoint.transform.get_forward_vector()
|
||||
wp_dir = trigger_wp.transform.get_forward_vector()
|
||||
dot_ve_wp = ve_dir.x * wp_dir.x + ve_dir.y * wp_dir.y + ve_dir.z * wp_dir.z
|
||||
|
||||
if dot_ve_wp < 0:
|
||||
|
@ -284,7 +293,7 @@ class BasicAgent(object):
|
|||
if traffic_light.state != carla.TrafficLightState.Red:
|
||||
continue
|
||||
|
||||
if is_within_distance(object_waypoint.transform, self._vehicle.get_transform(), max_distance, [0, 90]):
|
||||
if is_within_distance(trigger_wp.transform, self._vehicle.get_transform(), max_distance, [0, 90]):
|
||||
self._last_traffic_light = traffic_light
|
||||
return (True, traffic_light)
|
||||
|
||||
|
@ -326,6 +335,9 @@ class BasicAgent(object):
|
|||
|
||||
for target_vehicle in vehicle_list:
|
||||
target_transform = target_vehicle.get_transform()
|
||||
if target_transform.location.distance(ego_transform.location) > max_distance:
|
||||
continue
|
||||
|
||||
target_wpt = self._map.get_waypoint(target_transform.location, lane_type=carla.LaneType.Any)
|
||||
|
||||
# Simplified version for outside junctions
|
||||
|
|
|
@ -714,7 +714,7 @@ def game_loop(args):
|
|||
random.seed(args.seed)
|
||||
|
||||
client = carla.Client(args.host, args.port)
|
||||
client.set_timeout(4.0)
|
||||
client.set_timeout(60.0)
|
||||
|
||||
traffic_manager = client.get_trafficmanager()
|
||||
sim_world = client.get_world()
|
||||
|
@ -736,11 +736,13 @@ def game_loop(args):
|
|||
controller = KeyboardControl(world)
|
||||
if args.agent == "Basic":
|
||||
agent = BasicAgent(world.player, 30)
|
||||
agent.follow_speed_limits(True)
|
||||
elif args.agent == "Constant":
|
||||
agent = ConstantVelocityAgent(world.player, 30)
|
||||
ground_loc = world.world.ground_projection(world.player.get_location(), 5)
|
||||
if ground_loc:
|
||||
world.player.set_location(ground_loc.location + carla.Location(z=0.01))
|
||||
agent.follow_speed_limits(True)
|
||||
elif args.agent == "Behavior":
|
||||
agent = BehaviorAgent(world.player, behavior=args.behavior)
|
||||
|
||||
|
|
Loading…
Reference in New Issue