diff --git a/Lib/sets.py b/Lib/sets.py index 5f0f0a2d38c1..bbb93a091ebd 100644 --- a/Lib/sets.py +++ b/Lib/sets.py @@ -320,6 +320,8 @@ def _update(self, iterable): return value = True + if type(iterable) not in (list, tuple, dict, file, xrange, str): + iterable = list(iterable) it = iter(iterable) while True: try: diff --git a/Lib/test/test_sets.py b/Lib/test/test_sets.py index 840036cc654d..76b56b118d86 100644 --- a/Lib/test/test_sets.py +++ b/Lib/test/test_sets.py @@ -132,6 +132,30 @@ def setUp(self): #============================================================================== +def baditer(): + raise TypeError + yield True + +def gooditer(): + yield True + +class TestExceptionPropagation(unittest.TestCase): + """SF 628246: Set constructor should not trap iterator TypeErrors""" + + def test_instanceWithException(self): + self.assertRaises(TypeError, Set, baditer()) + + def test_instancesWithoutException(self): + """All of these iterables should load without exception.""" + Set([1,2,3]) + Set((1,2,3)) + Set({'one':1, 'two':2, 'three':3}) + Set(xrange(3)) + Set('abc') + Set(gooditer()) + +#============================================================================== + class TestSetOfSets(unittest.TestCase): def test_constructor(self): inner = Set([1]) @@ -604,6 +628,7 @@ def setUp(self): def makeAllTests(): suite = unittest.TestSuite() for klass in (TestSetOfSets, + TestExceptionPropagation, TestBasicOpsEmpty, TestBasicOpsSingleton, TestBasicOpsTuple,