diff --git a/core/roslib/src/roslib/rostime.py b/core/roslib/src/roslib/rostime.py index c0807553..a2fcb7fb 100644 --- a/core/roslib/src/roslib/rostime.py +++ b/core/roslib/src/roslib/rostime.py @@ -411,15 +411,68 @@ class Duration(TVal): def __mul__(self, val): """ - Multiply this duration by an integer + Multiply this duration by an integer or float @param val: multiplication factor - @type val: int + @type val: int/float @return: Duration multiplied by val @rtype: L{Duration} """ - if not type(val) == int: + t = type(val) + if t in (int, long): + return Duration(self.secs * val, self.nsecs * val) + elif t == float: + return Duration.from_sec(self.to_sec() * val) + else: + return NotImplemented + + def __floordiv__(self, val): + """ + Floor divide this duration by an integer or float + @param val: division factor + @type val: int/float + @return: Duration multiplied by val + @rtype: L{Duration} + """ + t = type(val) + if t in (int, long): + return Duration(self.secs // val, self.nsecs // val) + elif t == float: + return Duration.from_sec(self.to_sec() // val) + else: + return NotImplemented + + def __div__(self, val): + """ + Divide this duration by an integer or float + @param val: division factor + @type val: int/float + @return: Duration multiplied by val + @rtype: L{Duration} + """ + # unlike __floordiv__, this uses true div for float arg + t = type(val) + if t in (int, long): + return Duration(self.secs // val, self.nsecs // val) + elif t == float: + return Duration.from_sec(self.to_sec() / val) + else: + return NotImplemented + + def __truediv__(self, val): + """ + Divide this duration by an integer or float + @param val: division factor + @type val: int/float + @return: Duration multiplied by val + @rtype: L{Duration} + """ + t = type(val) + if t in (int, long): + return Duration(self.secs / val, self.nsecs / val) + elif t == float: + return Duration.from_sec(self.to_sec() / val) + else: return NotImplemented - return Duration(self.secs * val, self.nsecs * val) def __cmp__(self, other): if not isinstance(other, Duration): diff --git a/test/test_roslib/test/test_roslib_rostime.py b/test/test_roslib/test/test_roslib_rostime.py index c3f890d8..a36f7890 100644 --- a/test/test_roslib/test/test_roslib_rostime.py +++ b/test/test_roslib/test/test_roslib_rostime.py @@ -409,6 +409,31 @@ class RostimeTest(unittest.TestCase): self.fail("should have thrown value error") except ValueError: pass + # Test mul + self.assertEquals(Duration(4), Duration(2) * 2) + self.assertEquals(Duration(4), Duration(2) * 2.) + self.assertEquals(Duration(10), Duration(4) * 2.5) + self.assertEquals(Duration(4, 8), Duration(2, 4) * 2) + v = Duration(4, 8) - (Duration(2, 4) * 2.) + self.assert_(abs(v.to_nsec()) < 100) + v = Duration(5, 10) - (Duration(2, 4) * 2.5) + self.assert_(abs(v.to_nsec()) < 100) + + # Test div + self.assertEquals(Duration(4), Duration(8) / 2) + self.assertEquals(Duration(4), Duration(8) / 2.) + self.assertEquals(Duration(4), Duration(8) // 2) + self.assertEquals(Duration(4), Duration(8) // 2.) + self.assertEquals(Duration(4), Duration(9) // 2) + self.assertEquals(Duration(4), Duration(9) // 2.) + self.assertEquals(Duration(4, 2), Duration(8, 4) / 2) + v = Duration(4, 2) - (Duration(8, 4) / 2.) + self.assert_(abs(v.to_nsec()) < 100) + + self.assertEquals(Duration(4, 2), Duration(8, 4) // 2) + self.assertEquals(Duration(4, 2), Duration(9, 5) // 2) + v = Duration(4, 2) - (Duration(9, 5) // 2.) + self.assert_(abs(v.to_nsec()) < 100) if __name__ == '__main__': rostest.unitrun('test_roslib', 'test_rostime', RostimeTest, coverage_packages=['roslib.rostime'])