Improved Python agent performance (#5555)

* Improved agent performance

* CHANGELOG
This commit is contained in:
glopezdiest 2022-07-01 18:30:56 +02:00 committed by GitHub
parent eef45bc208
commit bf2815ec58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 10 deletions

View File

@ -14,6 +14,7 @@
* Added check to avoid adding procedural trigger boxes inside intersections.
* Python agents now accept a carla.Map and GlobalRoutePlanner instances as inputs, avoiding the need to recompute them.
* Python agents now have a function to lane change.
* Improved Python agents performance for large maps.
* Fix a bug at `Map.get_topology()`, causing lanes with no successors to not be part of it.
* Added new ConstantVelocityAgent
* Added new parameter to the TrafficManager, `set_desired_speed`, to set a vehicle's speed.

View File

@ -89,6 +89,10 @@ class BasicAgent(object):
else:
self._global_planner = GlobalRoutePlanner(self._map, self._sampling_resolution)
# Get the static elements of the scene
self._lights_list = self._world.get_actors().filter("*traffic_light*")
self._lights_map = {} # Dictionary mapping a traffic light to a wp corrspoing to its trigger volume location
def add_emergency_stop(self, control):
"""
Overwrites the throttle a brake values of a control to perform an emergency stop.
@ -178,9 +182,7 @@ class BasicAgent(object):
hazard_detected = False
# Retrieve all relevant actors
actor_list = self._world.get_actors()
vehicle_list = actor_list.filter("*vehicle*")
lights_list = actor_list.filter("*traffic_light*")
vehicle_list = self._world.get_actors().filter("*vehicle*")
vehicle_speed = get_speed(self._vehicle) / 3.6
@ -192,7 +194,7 @@ class BasicAgent(object):
# Check if the vehicle is affected by a red traffic light
max_tlight_distance = self._base_tlight_threshold + vehicle_speed
affected_by_tlight, _ = self._affected_by_traffic_light(lights_list, max_tlight_distance)
affected_by_tlight, _ = self._affected_by_traffic_light(self._lights_list, max_tlight_distance)
if affected_by_tlight:
hazard_detected = True
@ -268,14 +270,21 @@ class BasicAgent(object):
ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location)
for traffic_light in lights_list:
object_location = get_trafficlight_trigger_location(traffic_light)
object_waypoint = self._map.get_waypoint(object_location)
if traffic_light.id in self._lights_map:
trigger_wp = self._lights_map[traffic_light.id]
else:
trigger_location = get_trafficlight_trigger_location(traffic_light)
trigger_wp = self._map.get_waypoint(trigger_location)
self._lights_map[traffic_light.id] = trigger_wp
if object_waypoint.road_id != ego_vehicle_waypoint.road_id:
if trigger_wp.transform.location.distance(ego_vehicle_location) > max_distance:
continue
if trigger_wp.road_id != ego_vehicle_waypoint.road_id:
continue
ve_dir = ego_vehicle_waypoint.transform.get_forward_vector()
wp_dir = object_waypoint.transform.get_forward_vector()
wp_dir = trigger_wp.transform.get_forward_vector()
dot_ve_wp = ve_dir.x * wp_dir.x + ve_dir.y * wp_dir.y + ve_dir.z * wp_dir.z
if dot_ve_wp < 0:
@ -284,7 +293,7 @@ class BasicAgent(object):
if traffic_light.state != carla.TrafficLightState.Red:
continue
if is_within_distance(object_waypoint.transform, self._vehicle.get_transform(), max_distance, [0, 90]):
if is_within_distance(trigger_wp.transform, self._vehicle.get_transform(), max_distance, [0, 90]):
self._last_traffic_light = traffic_light
return (True, traffic_light)
@ -326,6 +335,9 @@ class BasicAgent(object):
for target_vehicle in vehicle_list:
target_transform = target_vehicle.get_transform()
if target_transform.location.distance(ego_transform.location) > max_distance:
continue
target_wpt = self._map.get_waypoint(target_transform.location, lane_type=carla.LaneType.Any)
# Simplified version for outside junctions

View File

@ -714,7 +714,7 @@ def game_loop(args):
random.seed(args.seed)
client = carla.Client(args.host, args.port)
client.set_timeout(4.0)
client.set_timeout(60.0)
traffic_manager = client.get_trafficmanager()
sim_world = client.get_world()
@ -736,11 +736,13 @@ def game_loop(args):
controller = KeyboardControl(world)
if args.agent == "Basic":
agent = BasicAgent(world.player, 30)
agent.follow_speed_limits(True)
elif args.agent == "Constant":
agent = ConstantVelocityAgent(world.player, 30)
ground_loc = world.world.ground_projection(world.player.get_location(), 5)
if ground_loc:
world.player.set_location(ground_loc.location + carla.Location(z=0.01))
agent.follow_speed_limits(True)
elif args.agent == "Behavior":
agent = BehaviorAgent(world.player, behavior=args.behavior)