店播爬取Python脚本

message_test.py 107KB


  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Protocol Buffers - Google's data interchange format
  5. # Copyright 2008 Google Inc. All rights reserved.
  6. # https://developers.google.com/protocol-buffers/
  7. #
  8. # Redistribution and use in source and binary forms, with or without
  9. # modification, are permitted provided that the following conditions are
  10. # met:
  11. #
  12. # * Redistributions of source code must retain the above copyright
  13. # notice, this list of conditions and the following disclaimer.
  14. # * Redistributions in binary form must reproduce the above
  15. # copyright notice, this list of conditions and the following disclaimer
  16. # in the documentation and/or other materials provided with the
  17. # distribution.
  18. # * Neither the name of Google Inc. nor the names of its
  19. # contributors may be used to endorse or promote products derived from
  20. # this software without specific prior written permission.
  21. #
  22. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  23. # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  24. # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  25. # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  26. # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  27. # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  28. # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  29. # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  30. # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  31. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  32. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  33. """Tests python protocol buffers against the golden message.
  34. Note that the golden messages exercise every known field type, thus this
  35. test ends up exercising and verifying nearly all of the parsing and
  36. serialization code in the whole library.
  37. TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of
  38. sense to call this a test of the "message" module, which only declares an
  39. abstract interface.
  40. """
  41. __author__ = 'gps@google.com (Gregory P. Smith)'
  42. import copy
  43. import math
  44. import operator
  45. import pickle
  46. import pydoc
  47. import six
  48. import sys
  49. import warnings
  50. try:
  51. # Since python 3
  52. import collections.abc as collections_abc
  53. except ImportError:
  54. # Won't work after python 3.8
  55. import collections as collections_abc
  56. try:
  57. import unittest2 as unittest # PY26
  58. except ImportError:
  59. import unittest
  try:
    cmp  # Python 2 ships this as a builtin.
  except NameError:
    # Python 3 dropped the builtin, so supply an equivalent three-way compare.
    def cmp(x, y):
      """Return a negative, zero or positive int as x <, == or > y."""
      return (x > y) - (x < y)
  64. from google.protobuf import map_proto2_unittest_pb2
  65. from google.protobuf import map_unittest_pb2
  66. from google.protobuf import unittest_pb2
  67. from google.protobuf import unittest_proto3_arena_pb2
  68. from google.protobuf import descriptor_pb2
  69. from google.protobuf import descriptor_pool
  70. from google.protobuf import message_factory
  71. from google.protobuf import text_format
  72. from google.protobuf.internal import api_implementation
  73. from google.protobuf.internal import encoder
  74. from google.protobuf.internal import more_extensions_pb2
  75. from google.protobuf.internal import packed_field_test_pb2
  76. from google.protobuf.internal import test_util
  77. from google.protobuf.internal import test_proto3_optional_pb2
  78. from google.protobuf.internal import testing_refleaks
  79. from google.protobuf import message
  80. from google.protobuf.internal import _parameterized
  UCS2_MAXUNICODE = 0xFFFF  # == 65535
  if six.PY3:
    # Keep the Python 2 spelling `long` usable on Python 3.
    long = int
  84. # Python pre-2.6 does not have isinf() or isnan() functions, so we have
  85. # to provide our own.
  def isnan(val):
    """Report whether `val` is a NaN.

    Relies on NaN being the only value that compares unequal to itself.
    """
    differs_from_itself = val != val
    return differs_from_itself
  def isinf(val):
    """Report whether `val` is an infinity of either sign.

    An infinity multiplied by zero gives NaN, while a finite number gives
    zero.  NaN inputs are ruled out first because NaN * 0 is NaN as well.
    """
    if val != val:  # NaN is never equal to itself, and is not an infinity.
      return False
    zeroed = val * 0
    return zeroed != zeroed
  def IsPosInf(val):
    """Report whether `val` is positive infinity."""
    if not isinf(val):
      return False
    return val > 0
  def IsNegInf(val):
    """Report whether `val` is negative infinity."""
    if not isinf(val):
      return False
    return val < 0
  # Escalate every DeprecationWarning to an exception.
  warnings.simplefilter(action='error', category=DeprecationWarning)
  97. @_parameterized.named_parameters(
  98. ('_proto2', unittest_pb2),
  99. ('_proto3', unittest_proto3_arena_pb2))
  100. @testing_refleaks.TestCase
  101. class MessageTest(unittest.TestCase):
  def testBadUtf8String(self, message_module):
    """Parsing invalid UTF-8 into a string field names the offending field.

    Skipped unless the pure-python API implementation is active.
    """
    implementation = api_implementation.Type()
    if implementation != 'python':
      skip_reason = ("Skipping testBadUtf8String, currently only the python "
                     "api implementation raises UnicodeDecodeError when a "
                     "string field contains bad utf-8.")
      self.skipTest(skip_reason)
    malformed_payload = test_util.GoldenFileData('bad_utf8_string')
    with self.assertRaises(UnicodeDecodeError) as raised:
      message_module.TestAllTypes.FromString(malformed_payload)
    error_text = str(raised.exception)
    self.assertIn('TestAllTypes.optional_string', error_text)
  def testGoldenMessage(self, message_module):
    """The golden message re-serializes to its original bytes, also via deepcopy."""
    # Proto3 doesn't have the "default_foo" members or foreign enums, and
    # doesn't preserve unknown fields, so proto3 uses a golden message that
    # doesn't have these fields set.
    is_proto2 = message_module is unittest_pb2
    if is_proto2:
      golden_file = 'golden_message_oneof_implemented'
    else:
      golden_file = 'golden_message_proto3'
    golden_data = test_util.GoldenFileData(golden_file)

    parsed = message_module.TestAllTypes()
    parsed.ParseFromString(golden_data)
    if is_proto2:
      test_util.ExpectAllFieldsSet(self, parsed)
    self.assertEqual(golden_data, parsed.SerializeToString())

    duplicate = copy.deepcopy(parsed)
    self.assertEqual(golden_data, duplicate.SerializeToString())
  def testGoldenPackedMessage(self, message_module):
    """The packed golden file matches a message filled by SetAllPackedFields."""
    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
    parsed = message_module.TestPackedTypes()
    consumed = parsed.ParseFromString(golden_data)

    expected = message_module.TestPackedTypes()
    test_util.SetAllPackedFields(expected)

    # The value returned by ParseFromString must equal the input length.
    self.assertEqual(consumed, len(golden_data))
    self.assertEqual(expected, parsed)
    self.assertEqual(golden_data, expected.SerializeToString())

    duplicate = copy.deepcopy(parsed)
    self.assertEqual(golden_data, duplicate.SerializeToString())
  def testParseErrors(self, message_module):
    """FromString raises on invalid arguments and on malformed wire data."""
    target = message_module.TestAllTypes()
    with self.assertRaises(TypeError):
      target.FromString(0)
    with self.assertRaises(Exception):
      target.FromString('0')

    # TODO(jieluo): Fix cpp extension to raise error instead of warning.
    # b/27494216
    stray_end_group = encoder.TagBytes(1, 4)
    if api_implementation.Type() == 'python':
      with self.assertRaises(message.DecodeError) as raised:
        target.FromString(stray_end_group)
      self.assertEqual('Unexpected end-group tag.', str(raised.exception))

    # Field number 0 is illegal.
    with self.assertRaises(message.DecodeError):
      target.FromString(b'\3\4')
  def testDeterminismParameters(self, message_module):
    """SerializeToString accepts None/False/True for `deterministic`.

    Also checks that an error raised while truth-testing the argument
    propagates to the caller.
    """
    # This message is always deterministically serialized, even if determinism
    # is disabled, so we can use it to verify that all the determinism
    # parameters work correctly.
    expected_wire = (b'\xe2\x02\nOne string'
                     b'\xe2\x02\nTwo string'
                     b'\xe2\x02\nRed string'
                     b'\xe2\x02\x0bBlue string')
    subject = message_module.TestAllTypes()
    subject.repeated_string.extend(
        ['One string', 'Two string', 'Red string', 'Blue string'])

    for flag in (None, False, True):
      self.assertEqual(expected_wire,
                       subject.SerializeToString(deterministic=flag))

    class BadArgError(Exception):
      pass

    class BadArg(object):
      """Object whose truth value cannot be evaluated."""

      def __bool__(self):  # Python 3 hook.
        raise BadArgError()

      def __nonzero__(self):  # Python 2 hook.
        raise BadArgError()

    with self.assertRaises(BadArgError):
      subject.SerializeToString(deterministic=BadArg())
  def testPickleSupport(self, message_module):
    """A parsed golden message compares equal after a pickle round trip."""
    original = message_module.TestAllTypes()
    original.ParseFromString(test_util.GoldenFileData('golden_message'))
    restored = pickle.loads(pickle.dumps(original))
    self.assertEqual(restored, original)
  def testPickleNestedMessage(self, message_module):
    """An instance of a nested message type survives pickling."""
    nested_cls = message_module.TestPickleNestedMessage.NestedMessage
    original = nested_cls(bb=1)
    blob = pickle.dumps(original)
    self.assertEqual(pickle.loads(blob), original)
  def testPickleNestedNestedMessage(self, message_module):
    """An instance of a doubly nested message type survives pickling."""
    outer_cls = message_module.TestPickleNestedMessage
    innermost_cls = outer_cls.NestedMessage.NestedNestedMessage
    original = innermost_cls(cc=1)
    restored = pickle.loads(pickle.dumps(original))
    self.assertEqual(restored, original)
  def testPositiveInfinity(self, message_module):
    """+inf in float/double fields parses correctly and re-serializes exactly."""
    # The optional fields use identical bytes for both modules; only the
    # bytes for the two repeated fields differ.
    optional_bytes = (b'\x5D\x00\x00\x80\x7F'
                      b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F')
    if message_module is unittest_pb2:
      repeated_bytes = (b'\xCD\x02\x00\x00\x80\x7F'
                        b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
    else:
      repeated_bytes = (b'\xCA\x02\x04\x00\x00\x80\x7F'
                        b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
    wire = optional_bytes + repeated_bytes

    parsed = message_module.TestAllTypes()
    parsed.ParseFromString(wire)
    self.assertTrue(IsPosInf(parsed.optional_float))
    self.assertTrue(IsPosInf(parsed.optional_double))
    self.assertTrue(IsPosInf(parsed.repeated_float[0]))
    self.assertTrue(IsPosInf(parsed.repeated_double[0]))
    self.assertEqual(wire, parsed.SerializeToString())
  def testNegativeInfinity(self, message_module):
    """-inf in float/double fields parses correctly and re-serializes exactly."""
    # The optional fields use identical bytes for both modules; only the
    # bytes for the two repeated fields differ.
    optional_bytes = (b'\x5D\x00\x00\x80\xFF'
                      b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF')
    if message_module is unittest_pb2:
      repeated_bytes = (b'\xCD\x02\x00\x00\x80\xFF'
                        b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
    else:
      repeated_bytes = (b'\xCA\x02\x04\x00\x00\x80\xFF'
                        b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
    wire = optional_bytes + repeated_bytes

    parsed = message_module.TestAllTypes()
    parsed.ParseFromString(wire)
    self.assertTrue(IsNegInf(parsed.optional_float))
    self.assertTrue(IsNegInf(parsed.optional_double))
    self.assertTrue(IsNegInf(parsed.repeated_float[0]))
    self.assertTrue(IsNegInf(parsed.repeated_double[0]))
    self.assertEqual(wire, parsed.SerializeToString())
  def testNotANumber(self, message_module):
    """NaN fields stay NaN after parsing and after a serialize/parse cycle."""
    wire = (b'\x5D\x00\x00\xC0\x7F'
            b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
            b'\xCD\x02\x00\x00\xC0\x7F'
            b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')

    def assert_all_nan(candidate):
      self.assertTrue(isnan(candidate.optional_float))
      self.assertTrue(isnan(candidate.optional_double))
      self.assertTrue(isnan(candidate.repeated_float[0]))
      self.assertTrue(isnan(candidate.repeated_double[0]))

    parsed = message_module.TestAllTypes()
    parsed.ParseFromString(wire)
    assert_all_nan(parsed)

    # The protocol buffer may serialize to any one of multiple different
    # representations of a NaN. Rather than verify a specific representation,
    # verify the serialized string can be converted into a correctly
    # behaving protocol buffer.
    reparsed = message_module.TestAllTypes()
    reparsed.ParseFromString(parsed.SerializeToString())
    assert_all_nan(reparsed)
  def testPositiveInfinityPacked(self, message_module):
    """+inf in packed float/double fields round-trips to identical bytes."""
    float_bytes = b'\xA2\x06\x04\x00\x00\x80\x7F'
    double_bytes = b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F'
    wire = float_bytes + double_bytes
    parsed = message_module.TestPackedTypes()
    parsed.ParseFromString(wire)
    self.assertTrue(IsPosInf(parsed.packed_float[0]))
    self.assertTrue(IsPosInf(parsed.packed_double[0]))
    self.assertEqual(wire, parsed.SerializeToString())
  def testNegativeInfinityPacked(self, message_module):
    """-inf in packed float/double fields round-trips to identical bytes."""
    float_bytes = b'\xA2\x06\x04\x00\x00\x80\xFF'
    double_bytes = b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF'
    wire = float_bytes + double_bytes
    parsed = message_module.TestPackedTypes()
    parsed.ParseFromString(wire)
    self.assertTrue(IsNegInf(parsed.packed_float[0]))
    self.assertTrue(IsNegInf(parsed.packed_double[0]))
    self.assertEqual(wire, parsed.SerializeToString())
  def testNotANumberPacked(self, message_module):
    """NaN in packed fields stays NaN after parse and after re-serialization."""
    wire = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
            b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')

    def assert_packed_nan(candidate):
      self.assertTrue(isnan(candidate.packed_float[0]))
      self.assertTrue(isnan(candidate.packed_double[0]))

    parsed = message_module.TestPackedTypes()
    parsed.ParseFromString(wire)
    assert_packed_nan(parsed)

    # Only NaN-ness is checked after the second parse, not the exact bytes.
    reparsed = message_module.TestPackedTypes()
    reparsed.ParseFromString(parsed.SerializeToString())
    assert_packed_nan(reparsed)
  def testExtremeFloatValues(self, message_module):
    """Boundary values of optional_float survive the wire format.

    Also checks that doubles too large for the field become +/-inf on
    assignment.
    """
    msg = message_module.TestAllTypes()

    def roundtrip(value):
      """Assigns value, serializes, re-parses and returns the stored field."""
      msg.optional_float = value
      msg.ParseFromString(msg.SerializeToString())
      return msg.optional_float

    # Most positive exponent, with no / one significand bit set.
    big_plain = math.pow(2, 127)
    big_one_bit = 1.5 * math.pow(2, 127)
    # Most negative exponent, with no / one significand bit set.
    small_plain = math.pow(2, -127)
    small_one_bit = 1.5 * math.pow(2, -127)

    # Each magnitude is checked with both signs.
    exact_cases = (
        big_plain, big_one_bit, -big_plain, -big_one_bit,
        small_plain, small_one_bit, -small_plain, -small_one_bit,
    )
    for expected in exact_cases:
      self.assertTrue(roundtrip(expected) == expected)

    # Max 4 bytes float value
    max_float = float.fromhex('0x1.fffffep+127')
    msg.optional_float = max_float
    self.assertAlmostEqual(msg.optional_float, max_float)
    msg.ParseFromString(msg.SerializeToString())
    self.assertAlmostEqual(msg.optional_float, max_float)

    # Test set double to float field.
    too_large = 3.4028235e+39
    msg.optional_float = too_large
    self.assertEqual(msg.optional_float, float('inf'))
    msg.ParseFromString(msg.SerializeToString())
    self.assertEqual(msg.optional_float, float('inf'))

    msg.optional_float = -too_large
    self.assertEqual(msg.optional_float, float('-inf'))

    tiny = 1.4028235e-39
    msg.optional_float = tiny
    self.assertAlmostEqual(msg.optional_float, tiny)
  def testExtremeDoubleValues(self, message_module):
    """Boundary values of optional_double survive the wire format exactly."""
    msg = message_module.TestAllTypes()

    def roundtrip(value):
      """Assigns value, serializes, re-parses and returns the stored field."""
      msg.optional_double = value
      msg.ParseFromString(msg.SerializeToString())
      return msg.optional_double

    # Most positive exponent, with no / one significand bit set.
    big_plain = math.pow(2, 1023)
    big_one_bit = 1.5 * math.pow(2, 1023)
    # Most negative exponent, with no / one significand bit set.
    small_plain = math.pow(2, -1023)
    small_one_bit = 1.5 * math.pow(2, -1023)

    # Each magnitude is checked with both signs.
    exact_cases = (
        big_plain, big_one_bit, -big_plain, -big_one_bit,
        small_plain, small_one_bit, -small_plain, -small_one_bit,
    )
    for expected in exact_cases:
      self.assertTrue(roundtrip(expected) == expected)
  def testFloatPrinting(self, message_module):
    """str() of a message renders a whole-number float with a trailing .0."""
    msg = message_module.TestAllTypes()
    msg.optional_float = 2.0
    rendered = str(msg)
    self.assertEqual(rendered, 'optional_float: 2.0\n')
  def testHighPrecisionFloatPrinting(self, message_module):
    """A stored float field value is unchanged by a serialize/parse cycle."""
    msg = message_module.TestAllTypes()
    msg.optional_float = 0.12345678912345678
    # Read the value back before the round trip: this is what the field holds.
    stored_before = msg.optional_float
    wire = msg.SerializeToString()
    msg.ParseFromString(wire)
    self.assertEqual(stored_before, msg.optional_float)
  def testHighPrecisionDoublePrinting(self, message_module):
    """str() prints a double with the precision expected for the Python version."""
    msg = message_module.TestAllTypes()
    msg.optional_double = 0.12345678912345678
    on_python3 = sys.version_info >= (3,)
    expected = ('optional_double: 0.12345678912345678\n' if on_python3
                else 'optional_double: 0.123456789123\n')
    self.assertEqual(str(msg), expected)
  def testUnknownFieldPrinting(self, message_module):
    """Fields unknown to TestEmptyMessage do not appear in its str() output."""
    source = message_module.TestAllTypes()
    test_util.SetAllNonLazyFields(source)
    wire = source.SerializeToString()

    receiver = message_module.TestEmptyMessage()
    receiver.ParseFromString(wire)
    self.assertEqual(str(receiver), '')
  def testAppendRepeatedCompositeField(self, message_module):
    """append() adds messages to a repeated field; a non-message adds nothing."""
    nested_cls = message_module.TestAllTypes.NestedMessage
    msg = message_module.TestAllTypes()
    msg.repeated_nested_message.append(nested_cls(bb=1))
    second = nested_cls(bb=2)
    msg.repeated_nested_message.append(second)
    try:
      msg.repeated_nested_message.append(1)
    except TypeError:
      # A TypeError is tolerated here; the checks below confirm the
      # container was left untouched either way.
      pass
    stored = [entry.bb for entry in msg.repeated_nested_message]
    self.assertEqual(2, len(msg.repeated_nested_message))
    self.assertEqual([1, 2], stored)
  def testInsertRepeatedCompositeField(self, message_module):
    """insert() on a repeated message field places elements like list.insert().

    Exercises negative, zero, past-the-end and far-negative indices, checks
    that an element fetched before later inserts still reads its own value,
    and verifies the text-format rendering of the final order.
    """
    nested_cls = message_module.TestAllTypes.NestedMessage
    msg = message_module.TestAllTypes()
    msg.repeated_nested_message.insert(-1, nested_cls(bb=1))
    sub_msg = msg.repeated_nested_message[0]
    msg.repeated_nested_message.insert(0, nested_cls(bb=2))
    msg.repeated_nested_message.insert(99, nested_cls(bb=3))
    msg.repeated_nested_message.insert(-2, nested_cls(bb=-1))
    msg.repeated_nested_message.insert(-1000, nested_cls(bb=-1000))
    try:
      msg.repeated_nested_message.insert(1, 999)
    except TypeError:
      # A TypeError is tolerated for a non-message value; the length check
      # below confirms nothing was inserted either way.
      pass
    self.assertEqual(5, len(msg.repeated_nested_message))
    self.assertEqual([-1000, 2, -1, 1, 3],
                     [m.bb for m in msg.repeated_nested_message])
    # Text format indents the fields of a nested message by two spaces.
    self.assertEqual(str(msg),
                     'repeated_nested_message {\n'
                     '  bb: -1000\n'
                     '}\n'
                     'repeated_nested_message {\n'
                     '  bb: 2\n'
                     '}\n'
                     'repeated_nested_message {\n'
                     '  bb: -1\n'
                     '}\n'
                     'repeated_nested_message {\n'
                     '  bb: 1\n'
                     '}\n'
                     'repeated_nested_message {\n'
                     '  bb: 3\n'
                     '}\n')
    # The element obtained before the later inserts still holds bb == 1.
    self.assertEqual(sub_msg.bb, 1)
  def testMergeFromRepeatedField(self, message_module):
    """MergeFrom on a repeated field appends the other field's elements."""
    target = message_module.TestAllTypes()
    for number in (1, 3):
      target.repeated_int32.append(number)
    for bb in (1, 2):
      target.repeated_nested_message.add(bb=bb)

    donor = message_module.TestAllTypes()
    for bb in (3, 4):
      donor.repeated_nested_message.add(bb=bb)
    for number in (5, 7):
      donor.repeated_int32.append(number)

    target.repeated_int32.MergeFrom(donor.repeated_int32)
    self.assertEqual(4, len(target.repeated_int32))

    target.repeated_nested_message.MergeFrom(donor.repeated_nested_message)
    merged_bbs = [entry.bb for entry in target.repeated_nested_message]
    self.assertEqual([1, 2, 3, 4], merged_bbs)
  def testAddWrongRepeatedNestedField(self, message_module):
    """Invalid add() calls leave the repeated message field empty."""
    msg = message_module.TestAllTypes()
    # (positional args, keyword args, exception type that is tolerated)
    bad_calls = (
        (('wrong',), {}, TypeError),
        ((), {'value_field': 'wrong'}, ValueError),
    )
    for args, kwargs, tolerated in bad_calls:
      try:
        msg.repeated_nested_message.add(*args, **kwargs)
      except tolerated:
        pass
    self.assertEqual(len(msg.repeated_nested_message), 0)
  def testRepeatedContains(self, message_module):
    """The `in` operator works on repeated scalar and repeated message fields."""
    msg = message_module.TestAllTypes()
    msg.repeated_int32.extend([1, 2, 3])
    self.assertIn(2, msg.repeated_int32)
    self.assertNotIn(0, msg.repeated_int32)

    nested_cls = message_module.TestAllTypes.NestedMessage
    msg.repeated_nested_message.add(bb=1)
    added = msg.repeated_nested_message[0]
    appended = nested_cls(bb=2)
    inserted = nested_cls(bb=3)
    msg.repeated_nested_message.append(appended)
    msg.repeated_nested_message.insert(0, inserted)
    for member in (added, appended, inserted):
      self.assertIn(member, msg.repeated_nested_message)
  def testRepeatedScalarIterable(self, message_module):
    """Iterating a repeated scalar field yields each stored value."""
    msg = message_module.TestAllTypes()
    msg.repeated_int32.extend([1, 2, 3])
    # sum() consumes the field through the iterator protocol.
    total = sum(msg.repeated_int32)
    self.assertEqual(total, 6)
  def testRepeatedNestedFieldIteration(self, message_module):
    """Forward, reversed() and [::-1] iteration of a repeated message field."""
    msg = message_module.TestAllTypes()
    for bb in (1, 2, 3, 4):
      msg.repeated_nested_message.add(bb=bb)

    forward = [m.bb for m in msg.repeated_nested_message]
    self.assertEqual([1, 2, 3, 4], forward)

    via_reversed = [m.bb for m in reversed(msg.repeated_nested_message)]
    self.assertEqual([4, 3, 2, 1], via_reversed)

    via_slice = [m.bb for m in msg.repeated_nested_message[::-1]]
    self.assertEqual([4, 3, 2, 1], via_slice)
  def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
    """Sort int32, float, string and bytes repeated fields with plain sort()."""
    msg = message_module.TestAllTypes()

    # TODO(mattp): would testing more scalar types strengthen test?
    for value in (1, 3, 2):
      msg.repeated_int32.append(value)
    msg.repeated_int32.sort()
    for index, expected in enumerate((1, 2, 3)):
      self.assertEqual(msg.repeated_int32[index], expected)
    self.assertEqual(str(msg.repeated_int32), str([1, 2, 3]))

    for value in (1.1, 1.3, 1.2):
      msg.repeated_float.append(value)
    msg.repeated_float.sort()
    for index, expected in enumerate((1.1, 1.2, 1.3)):
      self.assertAlmostEqual(msg.repeated_float[index], expected)

    for value in ('a', 'c', 'b'):
      msg.repeated_string.append(value)
    msg.repeated_string.sort()
    for index, expected in enumerate(('a', 'b', 'c')):
      self.assertEqual(msg.repeated_string[index], expected)
    self.assertEqual(str(msg.repeated_string), str([u'a', u'b', u'c']))

    for value in (b'a', b'c', b'b'):
      msg.repeated_bytes.append(value)
    msg.repeated_bytes.sort()
    for index, expected in enumerate((b'a', b'b', b'c')):
      self.assertEqual(msg.repeated_bytes[index], expected)
    self.assertEqual(str(msg.repeated_bytes), str([b'a', b'b', b'c']))
  def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
    """Sort repeated scalar fields with a key function (abs and len)."""
    msg = message_module.TestAllTypes()

    for value in (-3, -2, -1):
      msg.repeated_int32.append(value)
    msg.repeated_int32.sort(key=abs)
    for index, expected in enumerate((-1, -2, -3)):
      self.assertEqual(msg.repeated_int32[index], expected)

    for value in ('aaa', 'bb', 'c'):
      msg.repeated_string.append(value)
    msg.repeated_string.sort(key=len)
    for index, expected in enumerate(('c', 'bb', 'aaa')):
      self.assertEqual(msg.repeated_string[index], expected)
  def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
    """Sort a repeated message field by its `bb` attribute via a key function."""
    msg = message_module.TestAllTypes()
    for bb in (1, 3, 2, 6, 5, 4):
      msg.repeated_nested_message.add().bb = bb

    by_bb = operator.attrgetter('bb')
    msg.repeated_nested_message.sort(key=by_bb)

    for index, expected in enumerate((1, 2, 3, 4, 5, 6)):
      self.assertEqual(msg.repeated_nested_message[index].bb, expected)
    rendered = str(msg.repeated_nested_message)
    self.assertEqual(rendered,
                     '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]')
  def testSortingRepeatedCompositeFieldsStable(self, message_module):
    """Sorting a repeated message field keeps equal-key elements in order."""
    msg = message_module.TestAllTypes()
    for bb in (21, 20, 13, 33, 11, 24, 10):
      msg.repeated_nested_message.add().bb = bb

    def tens_bucket(entry):
      # Elements sharing a tens digit compare equal under this key.
      return entry.bb // 10

    msg.repeated_nested_message.sort(key=tens_bucket)
    expected_order = [13, 11, 10, 21, 20, 24, 33]
    self.assertEqual(expected_order,
                     [n.bb for n in msg.repeated_nested_message])

    # Make sure that for the C++ implementation, the underlying fields
    # are actually reordered.
    wire = msg.SerializeToString()
    msg.Clear()
    msg.MergeFromString(wire)
    self.assertEqual(expected_order,
                     [n.bb for n in msg.repeated_nested_message])
  def testRepeatedCompositeFieldSortArguments(self, message_module):
    """Sort a repeated message field with key/reverse (and cmp on Python 2)."""
    msg = message_module.TestAllTypes()
    for bb in (1, 3, 2, 6, 5, 4):
      msg.repeated_nested_message.add().bb = bb

    ascending = [1, 2, 3, 4, 5, 6]
    descending = [6, 5, 4, 3, 2, 1]
    by_bb = operator.attrgetter('bb')

    def current_order():
      return [k.bb for k in msg.repeated_nested_message]

    msg.repeated_nested_message.sort(key=by_bb)
    self.assertEqual(current_order(), ascending)
    msg.repeated_nested_message.sort(key=by_bb, reverse=True)
    self.assertEqual(current_order(), descending)

    if sys.version_info < (3,):  # No cmp sorting in PY3.
      def cmp_bb(a, b):
        return cmp(a.bb, b.bb)

      msg.repeated_nested_message.sort(sort_function=cmp_bb)
      self.assertEqual(current_order(), ascending)
      msg.repeated_nested_message.sort(cmp=cmp_bb, reverse=True)
      self.assertEqual(current_order(), descending)
  def testRepeatedScalarFieldSortArguments(self, message_module):
    """Sort repeated scalar fields with key/reverse (and cmp on Python 2)."""
    msg = message_module.TestAllTypes()
    python2 = sys.version_info < (3,)

    def exercise(field, key_fn, ascending, descending):
      """Sorts `field` by key_fn in both directions and checks the order."""
      field.sort(key=key_fn)
      self.assertEqual(list(field), ascending)
      field.sort(key=key_fn, reverse=True)
      self.assertEqual(list(field), descending)
      if python2:  # No cmp sorting in PY3.
        compare = lambda a, b: cmp(key_fn(a), key_fn(b))
        field.sort(sort_function=compare)
        self.assertEqual(list(field), ascending)
        field.sort(cmp=compare, reverse=True)
        self.assertEqual(list(field), descending)

    for value in (-3, -2, -1):
      msg.repeated_int32.append(value)
    exercise(msg.repeated_int32, abs, [-1, -2, -3], [-3, -2, -1])

    for value in ('aaa', 'bb', 'c'):
      msg.repeated_string.append(value)
    exercise(msg.repeated_string, len, ['c', 'bb', 'aaa'], ['aaa', 'bb', 'c'])
  def testRepeatedFieldsComparable(self, message_module):
    """Checks that messages and repeated containers support cmp() on PY2."""
    left = message_module.TestAllTypes()
    right = message_module.TestAllTypes()

    # Populate both messages with identical content.
    for number in (0, 1, 2):
      left.repeated_int32.append(number)
    for number in (0, 1, 2):
      right.repeated_int32.append(number)
    for bb_value in (1, 2, 3):
      left.repeated_nested_message.add().bb = bb_value
    for bb_value in (1, 2, 3):
      right.repeated_nested_message.add().bb = bb_value

    if sys.version_info >= (3,):
      # No cmp() in PY3.
      return

    # These comparisons should not raise errors.
    _ = left < right
    _ = left.repeated_nested_message < right.repeated_nested_message

    # Make sure cmp always works. If it wasn't defined, these would be
    # id() comparisons and would all fail.
    self.assertEqual(cmp(left, right), 0)
    self.assertEqual(cmp(left.repeated_int32, right.repeated_int32), 0)
    self.assertEqual(cmp(left.repeated_int32, [0, 1, 2]), 0)
    self.assertEqual(
        cmp(left.repeated_nested_message, right.repeated_nested_message), 0)
    with self.assertRaises(TypeError):
      # Can't compare repeated composite containers to lists.
      cmp(left.repeated_nested_message, right.repeated_nested_message[:])

    # TODO(anuraag): Implement extensiondict comparison in C++ and then add test
  def testRepeatedFieldsAreSequences(self, message_module):
    """Checks that repeated fields are MutableSequence instances."""
    msg = message_module.TestAllTypes()
    sequence_type = collections_abc.MutableSequence
    self.assertIsInstance(msg.repeated_int32, sequence_type)
    self.assertIsInstance(msg.repeated_nested_message, sequence_type)
  def testRepeatedFieldsNotHashable(self, message_module):
    """Checks that hashing a repeated container raises TypeError."""
    msg = message_module.TestAllTypes()
    # The lambdas keep the attribute access inside the guarded call.
    self.assertRaises(TypeError, lambda: hash(msg.repeated_int32))
    self.assertRaises(TypeError, lambda: hash(msg.repeated_nested_message))
  def testRepeatedFieldInsideNestedMessage(self, message_module):
    """Extending a nested repeated field with nothing still sets the parent."""
    outer = message_module.NestedTestAllTypes()
    nothing = []
    outer.payload.repeated_int32.extend(nothing)
    self.assertTrue(outer.HasField('payload'))
  def testMergeFrom(self, message_module):
    """Checks that MergeFrom updates sub-messages that were already read."""
    target = message_module.TestAllTypes()
    source = message_module.TestAllTypes()

    # Cpp extension will lazily create a sub message which is immutable.
    held = target.optional_nested_message
    self.assertEqual(0, held.bb)
    source.optional_nested_message.bb = 1
    # Make sure cmessage pointing to a mutable message after merge instead of
    # the lazily created message.
    target.MergeFrom(source)
    self.assertEqual(1, held.bb)

    # Test more nested sub message.
    deep_target = message_module.NestedTestAllTypes()
    deep_source = message_module.NestedTestAllTypes()
    held = deep_target.child.payload.optional_nested_message
    self.assertEqual(0, held.bb)
    deep_source.child.payload.optional_nested_message.bb = 1
    deep_target.MergeFrom(deep_source)
    self.assertEqual(1, held.bb)

    # Test repeated field. The first assertion compares the target's
    # container with itself.
    self.assertEqual(
        deep_target.payload.repeated_nested_message,
        deep_target.payload.repeated_nested_message)
    held = deep_source.payload.repeated_nested_message.add()
    held.bb = 1
    deep_target.MergeFrom(deep_source)
    self.assertEqual(1, len(deep_target.payload.repeated_nested_message))
    self.assertEqual(1, held.bb)
  def testMergeFromString(self, message_module):
    """Checks MergeFromString after the nested message was already read."""
    target = message_module.TestAllTypes()
    source = message_module.TestAllTypes()

    # Cpp extension will lazily create a sub message which is immutable.
    self.assertEqual(0, target.optional_nested_message.bb)
    source.optional_nested_message.bb = 1

    # Make sure cmessage pointing to a mutable message after merge instead of
    # the lazily created message.
    wire_bytes = source.SerializeToString()
    target.MergeFromString(wire_bytes)
    self.assertEqual(1, target.optional_nested_message.bb)
  def testMergeFromStringUsingMemoryView(self, message_module):
    """Parses a message from a memoryview over its serialized bytes."""
    source = message_module.TestAllTypes()
    source.optional_string = 'scalar string'
    source.repeated_string.append('repeated string')
    source.optional_bytes = b'scalar bytes'
    source.repeated_bytes.append(b'repeated bytes')

    view = memoryview(source.SerializeToString())
    parsed = message_module.TestAllTypes.FromString(view)

    self.assertEqual(parsed.optional_bytes, b'scalar bytes')
    self.assertEqual(parsed.repeated_bytes, [b'repeated bytes'])
    self.assertEqual(parsed.optional_string, 'scalar string')
    self.assertEqual(parsed.repeated_string, ['repeated string'])

    # Make sure that the memoryview was correctly converted to bytes, and
    # that a sub-sliced memoryview is not being used.
    self.assertIsInstance(parsed.optional_bytes, bytes)
    self.assertIsInstance(parsed.repeated_bytes[0], bytes)
    self.assertIsInstance(parsed.optional_string, six.text_type)
    self.assertIsInstance(parsed.repeated_string[0], six.text_type)
  def testMergeFromEmpty(self, message_module):
    """Merging empty bytes must not mark the nested message as set."""
    msg = message_module.TestAllTypes()
    field = 'optional_nested_message'

    # Cpp extension will lazily create a sub message which is immutable.
    self.assertEqual(0, msg.optional_nested_message.bb)
    self.assertFalse(msg.HasField(field))

    # Make sure the sub message is still immutable after merge from empty.
    msg.MergeFromString(b'')  # field state should not change
    self.assertFalse(msg.HasField(field))
  def ensureNestedMessageExists(self, msg, attribute):
    """Reads a nested message attribute and checks it is still unset.

    Per the original note, accessing a nested message attribute places it in
    the _fields dict without marking it as actually being set.

    Args:
      msg: message whose attribute is read.
      attribute: name of the nested message field.
    """
    getattr(msg, attribute)
    is_set = msg.HasField(attribute)
    self.assertFalse(is_set)
  def testOneofGetCaseNonexistingField(self, message_module):
    """WhichOneof rejects an unknown oneof name and a non-string argument."""
    msg = message_module.TestAllTypes()
    with self.assertRaises(ValueError):
      msg.WhichOneof('no_such_oneof_field')
    with self.assertRaises(Exception):
      msg.WhichOneof(0)
  def testOneofDefaultValues(self, message_module):
    """Assigning a default value to a oneof member still selects it."""
    msg = message_module.TestAllTypes()
    self.assertIs(None, msg.WhichOneof('oneof_field'))
    self.assertFalse(msg.HasField('oneof_field'))
    self.assertFalse(msg.HasField('oneof_uint32'))

    # Oneof is set even when setting it to a default value.
    msg.oneof_uint32 = 0
    self.assertEqual('oneof_uint32', msg.WhichOneof('oneof_field'))
    self.assertTrue(msg.HasField('oneof_field'))
    self.assertTrue(msg.HasField('oneof_uint32'))
    self.assertFalse(msg.HasField('oneof_string'))

    # Switching to the string member, again with its default value.
    msg.oneof_string = ""
    self.assertEqual('oneof_string', msg.WhichOneof('oneof_field'))
    self.assertTrue(msg.HasField('oneof_string'))
    self.assertFalse(msg.HasField('oneof_uint32'))
  def testOneofSemantics(self, message_module):
    """Walks a oneof through several members, checking which one is set."""
    msg = message_module.TestAllTypes()
    self.assertIs(None, msg.WhichOneof('oneof_field'))

    msg.oneof_uint32 = 11
    self.assertEqual('oneof_uint32', msg.WhichOneof('oneof_field'))
    self.assertTrue(msg.HasField('oneof_uint32'))

    msg.oneof_string = u'foo'
    self.assertEqual('oneof_string', msg.WhichOneof('oneof_field'))
    self.assertFalse(msg.HasField('oneof_uint32'))
    self.assertTrue(msg.HasField('oneof_string'))

    # Read nested message accessor without accessing submessage.
    _ = msg.oneof_nested_message
    self.assertEqual('oneof_string', msg.WhichOneof('oneof_field'))
    self.assertTrue(msg.HasField('oneof_string'))
    self.assertFalse(msg.HasField('oneof_nested_message'))

    # Read accessor of nested message without accessing submessage.
    _ = msg.oneof_nested_message.bb
    self.assertEqual('oneof_string', msg.WhichOneof('oneof_field'))
    self.assertTrue(msg.HasField('oneof_string'))
    self.assertFalse(msg.HasField('oneof_nested_message'))

    # Writing into the nested message switches the oneof to it.
    msg.oneof_nested_message.bb = 11
    self.assertEqual('oneof_nested_message', msg.WhichOneof('oneof_field'))
    self.assertFalse(msg.HasField('oneof_string'))
    self.assertTrue(msg.HasField('oneof_nested_message'))

    msg.oneof_bytes = b'bb'
    self.assertEqual('oneof_bytes', msg.WhichOneof('oneof_field'))
    self.assertFalse(msg.HasField('oneof_nested_message'))
    self.assertTrue(msg.HasField('oneof_bytes'))
  def testOneofCompositeFieldReadAccess(self, message_module):
    """Reading the nested oneof member leaves the scalar member selected."""
    msg = message_module.TestAllTypes()
    msg.oneof_uint32 = 11
    self.ensureNestedMessageExists(msg, 'oneof_nested_message')
    selected = msg.WhichOneof('oneof_field')
    self.assertEqual('oneof_uint32', selected)
    self.assertEqual(11, msg.oneof_uint32)
  def testOneofWhichOneof(self, message_module):
    """Tracks WhichOneof across assignments and ClearField."""
    msg = message_module.TestAllTypes()
    # HasField on the oneof name is only checked for the unittest_pb2 module.
    check_has_oneof = message_module is unittest_pb2

    self.assertIs(None, msg.WhichOneof('oneof_field'))
    if check_has_oneof:
      self.assertFalse(msg.HasField('oneof_field'))

    msg.oneof_uint32 = 11
    self.assertEqual('oneof_uint32', msg.WhichOneof('oneof_field'))
    if check_has_oneof:
      self.assertTrue(msg.HasField('oneof_field'))

    msg.oneof_bytes = b'bb'
    self.assertEqual('oneof_bytes', msg.WhichOneof('oneof_field'))

    msg.ClearField('oneof_bytes')
    self.assertIs(None, msg.WhichOneof('oneof_field'))
    if check_has_oneof:
      self.assertFalse(msg.HasField('oneof_field'))
  def testOneofClearField(self, message_module):
    """Clearing the oneof by its own name unsets the selected member."""
    msg = message_module.TestAllTypes()
    msg.oneof_uint32 = 11
    msg.ClearField('oneof_field')

    # HasField on the oneof name is only checked for the unittest_pb2 module.
    check_has_oneof = message_module is unittest_pb2
    if check_has_oneof:
      self.assertFalse(msg.HasField('oneof_field'))
    self.assertFalse(msg.HasField('oneof_uint32'))
    self.assertIs(None, msg.WhichOneof('oneof_field'))
  def testOneofClearSetField(self, message_module):
    """Clearing the currently selected member leaves the oneof empty."""
    msg = message_module.TestAllTypes()
    msg.oneof_uint32 = 11
    msg.ClearField('oneof_uint32')

    # HasField on the oneof name is only checked for the unittest_pb2 module.
    check_has_oneof = message_module is unittest_pb2
    if check_has_oneof:
      self.assertFalse(msg.HasField('oneof_field'))
    self.assertFalse(msg.HasField('oneof_uint32'))
    self.assertIs(None, msg.WhichOneof('oneof_field'))
  def testOneofClearUnsetField(self, message_module):
    """Clearing a member that is not selected keeps the selected one."""
    msg = message_module.TestAllTypes()
    msg.oneof_uint32 = 11
    self.ensureNestedMessageExists(msg, 'oneof_nested_message')
    msg.ClearField('oneof_nested_message')

    self.assertEqual(11, msg.oneof_uint32)
    # HasField on the oneof name is only checked for the unittest_pb2 module.
    check_has_oneof = message_module is unittest_pb2
    if check_has_oneof:
      self.assertTrue(msg.HasField('oneof_field'))
    self.assertTrue(msg.HasField('oneof_uint32'))
    self.assertEqual('oneof_uint32', msg.WhichOneof('oneof_field'))
  def testOneofDeserialize(self, message_module):
    """The selected oneof member survives a serialize/parse round trip."""
    original = message_module.TestAllTypes()
    original.oneof_uint32 = 11

    parsed = message_module.TestAllTypes()
    wire_bytes = original.SerializeToString()
    parsed.ParseFromString(wire_bytes)
    self.assertEqual('oneof_uint32', parsed.WhichOneof('oneof_field'))
  def testOneofCopyFrom(self, message_module):
    """CopyFrom carries over the selected oneof member."""
    original = message_module.TestAllTypes()
    original.oneof_uint32 = 11

    duplicate = message_module.TestAllTypes()
    duplicate.CopyFrom(original)
    selected = duplicate.WhichOneof('oneof_field')
    self.assertEqual('oneof_uint32', selected)
  def testOneofNestedMergeFrom(self, message_module):
    """MergeFrom replaces a nested oneof only where the source sets one."""
    source = message_module.NestedTestAllTypes()
    source.payload.oneof_uint32 = 11

    target = message_module.NestedTestAllTypes()
    target.payload.oneof_bytes = b'bb'
    target.child.payload.oneof_bytes = b'bb'

    target.MergeFrom(source)
    self.assertEqual(
        'oneof_uint32', target.payload.WhichOneof('oneof_field'))
    self.assertEqual(
        'oneof_bytes', target.child.payload.WhichOneof('oneof_field'))
  def testOneofMessageMergeFrom(self, message_module):
    """MergeFrom brings message-typed oneof members into the target."""
    source = message_module.NestedTestAllTypes()
    source.payload.oneof_nested_message.bb = 11
    source.child.payload.oneof_nested_message.bb = 12

    target = message_module.NestedTestAllTypes()
    target.payload.oneof_uint32 = 13
    target.MergeFrom(source)

    expected = 'oneof_nested_message'
    self.assertEqual(expected, target.payload.WhichOneof('oneof_field'))
    self.assertEqual(expected, target.child.payload.WhichOneof('oneof_field'))
  def testOneofNestedMessageInit(self, message_module):
    """A nested oneof member passed to the constructor is selected."""
    nested = message_module.TestAllTypes.NestedMessage()
    msg = message_module.TestAllTypes(oneof_nested_message=nested)
    self.assertEqual('oneof_nested_message', msg.WhichOneof('oneof_field'))
  def testOneofClear(self, message_module):
    """Clear() resets the oneof, which can then be assigned again."""
    msg = message_module.TestAllTypes()
    msg.oneof_uint32 = 11

    msg.Clear()
    self.assertIsNone(msg.WhichOneof('oneof_field'))

    msg.oneof_bytes = b'bb'
    self.assertEqual('oneof_bytes', msg.WhichOneof('oneof_field'))
  def testAssignByteStringToUnicodeField(self, message_module):
    """A str('') assigned to a string field reads back as six.text_type."""
    msg = message_module.TestAllTypes()
    empty = str('')
    msg.optional_string = empty
    self.assertIsInstance(msg.optional_string, six.text_type)
  def testLongValuedSlice(self, message_module):
    """Slices repeated fields using long-valued bounds.

    The original note records that this did not use to work in the v2 C++
    implementation.
    """
    msg = message_module.TestAllTypes()

    # Repeated scalar
    msg.repeated_int32.append(1)
    start, stop = long(0), long(len(msg.repeated_int32))
    scalar_slice = msg.repeated_int32[start:stop]
    self.assertEqual(len(msg.repeated_int32), len(scalar_slice))

    # Repeated composite
    msg.repeated_nested_message.add().bb = 3
    start, stop = long(0), long(len(msg.repeated_nested_message))
    composite_slice = msg.repeated_nested_message[start:stop]
    self.assertEqual(len(msg.repeated_nested_message), len(composite_slice))
  def testExtendShouldNotSwallowExceptions(self, message_module):
    """A NameError raised while iterating the argument must reach the caller.

    The original note records that this did not use to work in the v2 C++
    implementation.
    """
    msg = message_module.TestAllTypes()
    # The generators below deliberately reference the undefined name `a`.
    with self.assertRaises(NameError):
      msg.repeated_int32.extend(
          a for i in range(10))  # pylint: disable=undefined-variable
    with self.assertRaises(NameError):
      msg.repeated_nested_enum.extend(
          a for i in range(10))  # pylint: disable=undefined-variable
  # Falsy objects that the "extend with nothing" tests pass to extend().
  FALSY_VALUES = [
      None,
      False,
      0,
      0.0,
      b'',
      u'',
      bytearray(),
      [],
      {},
      set(),
  ]
  def testExtendInt32WithNothing(self, message_module):
    """Extending a repeated int32 field with falsy input adds nothing."""
    msg = message_module.TestAllTypes()
    nothing = []
    self.assertSequenceEqual(nothing, msg.repeated_int32)

    # TODO(ptucker): Deprecate this behavior. b/18413862
    for candidate in MessageTest.FALSY_VALUES:
      msg.repeated_int32.extend(candidate)
      self.assertSequenceEqual(nothing, msg.repeated_int32)

    msg.repeated_int32.extend([])
    self.assertSequenceEqual(nothing, msg.repeated_int32)
  def testExtendFloatWithNothing(self, message_module):
    """Extending a repeated float field with falsy input adds nothing."""
    msg = message_module.TestAllTypes()
    nothing = []
    self.assertSequenceEqual(nothing, msg.repeated_float)

    # TODO(ptucker): Deprecate this behavior. b/18413862
    for candidate in MessageTest.FALSY_VALUES:
      msg.repeated_float.extend(candidate)
      self.assertSequenceEqual(nothing, msg.repeated_float)

    msg.repeated_float.extend([])
    self.assertSequenceEqual(nothing, msg.repeated_float)
  def testExtendStringWithNothing(self, message_module):
    """Extending a repeated string field with falsy input adds nothing."""
    msg = message_module.TestAllTypes()
    nothing = []
    self.assertSequenceEqual(nothing, msg.repeated_string)

    # TODO(ptucker): Deprecate this behavior. b/18413862
    for candidate in MessageTest.FALSY_VALUES:
      msg.repeated_string.extend(candidate)
      self.assertSequenceEqual(nothing, msg.repeated_string)

    msg.repeated_string.extend([])
    self.assertSequenceEqual(nothing, msg.repeated_string)
  def testExtendInt32WithPythonList(self, message_module):
    """Extends a repeated int32 field with Python lists, step by step."""
    msg = message_module.TestAllTypes()
    self.assertSequenceEqual([], msg.repeated_int32)

    # Each step pairs the list passed to extend() with the expected content.
    steps = (
        ([0], [0]),
        ([1, 2], [0, 1, 2]),
        ([3, 4], [0, 1, 2, 3, 4]),
    )
    for addition, expected in steps:
      msg.repeated_int32.extend(addition)
      self.assertSequenceEqual(expected, msg.repeated_int32)
  def testExtendFloatWithPythonList(self, message_module):
    """Extends a repeated float field with Python lists, step by step."""
    msg = message_module.TestAllTypes()
    self.assertSequenceEqual([], msg.repeated_float)

    # Each step pairs the list passed to extend() with the expected content.
    steps = (
        ([0.0], [0.0]),
        ([1.0, 2.0], [0.0, 1.0, 2.0]),
        ([3.0, 4.0], [0.0, 1.0, 2.0, 3.0, 4.0]),
    )
    for addition, expected in steps:
      msg.repeated_float.extend(addition)
      self.assertSequenceEqual(expected, msg.repeated_float)
  def testExtendStringWithPythonList(self, message_module):
    """Extends a repeated string field with Python lists, step by step."""
    msg = message_module.TestAllTypes()
    self.assertSequenceEqual([], msg.repeated_string)

    # Each step pairs the list passed to extend() with the expected content.
    steps = (
        ([''], ['']),
        (['11', '22'], ['', '11', '22']),
        (['33', '44'], ['', '11', '22', '33', '44']),
    )
    for addition, expected in steps:
      msg.repeated_string.extend(addition)
      self.assertSequenceEqual(expected, msg.repeated_string)
  def testExtendStringWithString(self, message_module):
    """Extending with a plain string appends one entry per character."""
    msg = message_module.TestAllTypes()
    self.assertSequenceEqual([], msg.repeated_string)

    text = 'abc'
    msg.repeated_string.extend(text)
    self.assertSequenceEqual(['a', 'b', 'c'], msg.repeated_string)
  class TestIterable(object):
    """Iterable whose truth value mimics the behavior of numpy.array.

    Truth testing returns False for an empty instance, bool(item[0]) for a
    single element, and raises ValueError for more than one element.
    """

    def __init__(self, values=None):
      # A falsy argument (None or an empty sequence) becomes a fresh list.
      self._list = values or []

    def __nonzero__(self):
      """Returns the numpy-like truth value (Python 2 protocol name)."""
      size = len(self._list)
      if size == 0:
        return False
      if size == 1:
        return bool(self._list[0])
      raise ValueError('Truth value is ambiguous.')

    # Python 3 looks up __bool__ and otherwise falls back to __len__, which
    # would silently skip the ValueError above. Alias it so the documented
    # behavior holds on both major versions.
    __bool__ = __nonzero__

    def __len__(self):
      return len(self._list)

    def __iter__(self):
      return self._list.__iter__()
  def testExtendInt32WithIterable(self, message_module):
    """Extends a repeated int32 field with TestIterable instances."""
    msg = message_module.TestAllTypes()
    self.assertSequenceEqual([], msg.repeated_int32)

    # Each step pairs the wrapped values with the expected field content.
    steps = (
        ([], []),
        ([0], [0]),
        ([1, 2], [0, 1, 2]),
        ([3, 4], [0, 1, 2, 3, 4]),
    )
    for values, expected in steps:
      msg.repeated_int32.extend(MessageTest.TestIterable(values))
      self.assertSequenceEqual(expected, msg.repeated_int32)
  def testExtendFloatWithIterable(self, message_module):
    """Extends a repeated float field with TestIterable instances."""
    msg = message_module.TestAllTypes()
    self.assertSequenceEqual([], msg.repeated_float)

    # Each step pairs the wrapped values with the expected field content.
    steps = (
        ([], []),
        ([0.0], [0.0]),
        ([1.0, 2.0], [0.0, 1.0, 2.0]),
        ([3.0, 4.0], [0.0, 1.0, 2.0, 3.0, 4.0]),
    )
    for values, expected in steps:
      msg.repeated_float.extend(MessageTest.TestIterable(values))
      self.assertSequenceEqual(expected, msg.repeated_float)
  def testExtendStringWithIterable(self, message_module):
    """Extends a repeated string field with TestIterable instances."""
    msg = message_module.TestAllTypes()
    self.assertSequenceEqual([], msg.repeated_string)

    # Each step pairs the wrapped values with the expected field content.
    steps = (
        ([], []),
        ([''], ['']),
        (['1', '2'], ['', '1', '2']),
        (['3', '4'], ['', '1', '2', '3', '4']),
    )
    for values, expected in steps:
      msg.repeated_string.extend(MessageTest.TestIterable(values))
      self.assertSequenceEqual(expected, msg.repeated_string)
  def testPickleRepeatedScalarContainer(self, message_module):
    """Pickling a repeated scalar container raises PickleError on cpp."""
    # TODO(tibell): The pure-Python implementation support pickling of
    #   scalar containers in *some* cases. For now the cpp2 version
    #   throws an exception to avoid a segfault. Investigate if we
    #   want to support pickling of these fields.
    #
    # For more information see: https://b2.corp.google.com/u/0/issues/18677897
    applies = (api_implementation.Type() == 'cpp' and
               api_implementation.Version() != 2)
    if not applies:
      return

    msg = message_module.TestAllTypes()
    with self.assertRaises(pickle.PickleError):
      pickle.dumps(msg.repeated_int32, pickle.HIGHEST_PROTOCOL)
  def testSortEmptyRepeatedCompositeContainer(self, message_module):
    """Sorts an empty repeated message field.

    The original note records that this scenario has led to segfaults in the
    past.
    """
    msg = message_module.TestAllTypes()
    msg.repeated_nested_message.sort()
  def testHasFieldOnRepeatedField(self, message_module):
    """HasField on a repeated field raises ValueError."""
    msg = message_module.TestAllTypes()
    with self.assertRaises(ValueError):
      msg.HasField('repeated_int32')
  def testRepeatedScalarFieldPop(self, message_module):
    """Checks pop() on a repeated scalar field, with and without an index."""
    msg = message_module.TestAllTypes()

    # Popping from an empty container fails like list.pop().
    with self.assertRaises(IndexError):
      msg.repeated_int32.pop()

    msg.repeated_int32.extend(range(5))
    last = msg.repeated_int32.pop()
    self.assertEqual(4, last)
    first = msg.repeated_int32.pop(0)
    self.assertEqual(0, first)
    middle = msg.repeated_int32.pop(1)
    self.assertEqual(2, middle)
    self.assertEqual([1, 3], msg.repeated_int32)
  def testRepeatedCompositeFieldPop(self, message_module):
    """Checks pop() on a repeated message field, with and without an index."""
    msg = message_module.TestAllTypes()

    # Empty container and non-integer index are both rejected.
    with self.assertRaises(IndexError):
      msg.repeated_nested_message.pop()
    with self.assertRaises(TypeError):
      msg.repeated_nested_message.pop('0')

    for bb_value in range(5):
      entry = msg.repeated_nested_message.add()
      entry.bb = bb_value

    self.assertEqual(4, msg.repeated_nested_message.pop().bb)
    self.assertEqual(0, msg.repeated_nested_message.pop(0).bb)
    self.assertEqual(2, msg.repeated_nested_message.pop(1).bb)
    remaining = [entry.bb for entry in msg.repeated_nested_message]
    self.assertEqual([1, 3], remaining)
  def testRepeatedCompareWithSelf(self, message_module):
    """Repeated containers compare equal to themselves."""
    msg = message_module.TestAllTypes()
    for index in range(5):
      msg.repeated_int32.insert(index, index)
      entry = msg.repeated_nested_message.add()
      entry.bb = index

    self.assertSequenceEqual(msg.repeated_int32, msg.repeated_int32)
    self.assertEqual(
        msg.repeated_nested_message, msg.repeated_nested_message)
  def testReleasedNestedMessages(self, message_module):
    """Reads through a chain of children while dropping each parent.

    The original note records that this case led to a segfault when a message
    detached from its parent container has itself a child container.
    """
    # The same name is rebound at every step so that no reference to the
    # previous level is kept by this test.
    node = message_module.NestedTestAllTypes()
    node = node.repeated_child.add()
    node = node.child
    node = node.repeated_child.add()
    self.assertEqual(node.payload.optional_int32, 0)
  def testSetRepeatedComposite(self, message_module):
    """Assigning a list to a repeated field raises AttributeError."""
    msg = message_module.TestAllTypes()

    # Rejected while the field is still empty...
    with self.assertRaises(AttributeError):
      msg.repeated_int32 = []

    # ...and also once it holds a value.
    msg.repeated_int32.append(1)
    with self.assertRaises(AttributeError):
      msg.repeated_int32 = []
  def testReturningType(self, message_module):
    """Float, double and bool fields return float/bool, even when set to 1."""
    msg = message_module.TestAllTypes()

    # Types of the defaults.
    self.assertEqual(float, type(msg.optional_float))
    self.assertEqual(float, type(msg.optional_double))
    self.assertEqual(bool, type(msg.optional_bool))

    # Store the integer 1 everywhere, then round-trip through the wire format.
    one = 1
    msg.optional_float = one
    msg.optional_double = one
    msg.optional_bool = one
    msg.repeated_float.append(one)
    msg.repeated_double.append(one)
    msg.repeated_bool.append(one)
    wire_bytes = msg.SerializeToString()
    msg.ParseFromString(wire_bytes)

    self.assertEqual(float, type(msg.optional_float))
    self.assertEqual(float, type(msg.optional_double))
    self.assertEqual('1.0', str(msg.optional_double))
    self.assertEqual(bool, type(msg.optional_bool))
    self.assertEqual(float, type(msg.repeated_float[0]))
    self.assertEqual(float, type(msg.repeated_double[0]))
    self.assertEqual(bool, type(msg.repeated_bool[0]))
    self.assertEqual(True, msg.repeated_bool[0])
  1142. # Class to test proto2-only features (required, extensions, etc.)
  1143. @testing_refleaks.TestCase
  1144. class Proto2Test(unittest.TestCase):
  def testFieldPresence(self):
    """Checks HasField/ClearField behavior on unittest_pb2.TestAllTypes."""
    msg = unittest_pb2.TestAllTypes()

    def check_defaults():
      # The values read back as defaults whenever the fields are unset.
      self.assertEqual(0, msg.optional_int32)
      self.assertEqual(False, msg.optional_bool)
      self.assertEqual(0, msg.optional_nested_message.bb)

    # Nothing is set on a fresh message.
    self.assertFalse(msg.HasField("optional_int32"))
    self.assertFalse(msg.HasField("optional_bool"))
    self.assertFalse(msg.HasField("optional_nested_message"))

    # Unknown and repeated field names are rejected by HasField.
    with self.assertRaises(ValueError):
      msg.HasField("field_doesnt_exist")
    with self.assertRaises(ValueError):
      msg.HasField("repeated_int32")
    with self.assertRaises(ValueError):
      msg.HasField("repeated_nested_message")

    check_defaults()

    # Fields are set even when setting the values to default values.
    msg.optional_int32 = 0
    msg.optional_bool = False
    msg.optional_nested_message.bb = 0
    self.assertTrue(msg.HasField("optional_int32"))
    self.assertTrue(msg.HasField("optional_bool"))
    self.assertTrue(msg.HasField("optional_nested_message"))

    # Set the fields to non-default values.
    msg.optional_int32 = 5
    msg.optional_bool = True
    msg.optional_nested_message.bb = 15
    self.assertTrue(msg.HasField(u"optional_int32"))
    self.assertTrue(msg.HasField("optional_bool"))
    self.assertTrue(msg.HasField("optional_nested_message"))

    # Clearing the fields unsets them and resets their value to default.
    msg.ClearField("optional_int32")
    msg.ClearField(u"optional_bool")
    msg.ClearField("optional_nested_message")
    self.assertFalse(msg.HasField("optional_int32"))
    self.assertFalse(msg.HasField("optional_bool"))
    self.assertFalse(msg.HasField("optional_nested_message"))

    check_defaults()
  def testAssignInvalidEnum(self):
    """Proto2 rejects unknown enum numbers on assignment and append."""
    proto2_msg = unittest_pb2.TestAllTypes()

    # Proto2 can not assign unknown enum.
    with self.assertRaises(ValueError):
      proto2_msg.optional_nested_enum = 1234567
    with self.assertRaises(ValueError):
      proto2_msg.repeated_nested_enum.append(1234567)

    # Assignment is a different code path than append for the C++ impl.
    proto2_msg.repeated_nested_enum.append(2)
    proto2_msg.repeated_nested_enum[0] = 2
    with self.assertRaises(ValueError):
      proto2_msg.repeated_nested_enum[0] = 123456

    # Unknown enum value can be parsed but is ignored.
    proto3_msg = unittest_proto3_arena_pb2.TestAllTypes()
    proto3_msg.optional_nested_enum = 1234567
    proto3_msg.repeated_nested_enum.append(7654321)
    wire_bytes = proto3_msg.SerializeToString()

    parsed = unittest_pb2.TestAllTypes()
    parsed.ParseFromString(wire_bytes)
    self.assertFalse(parsed.HasField('optional_nested_enum'))
    # 1 is the default value for optional_nested_enum.
    self.assertEqual(1, parsed.optional_nested_enum)
    self.assertEqual(0, len(parsed.repeated_nested_enum))

    # Re-serializing the proto2 message hands the values back to proto3.
    proto3_msg.Clear()
    proto3_msg.ParseFromString(parsed.SerializeToString())
    self.assertEqual(1234567, proto3_msg.optional_nested_enum)
    self.assertEqual(7654321, proto3_msg.repeated_nested_enum[0])
  def testUnknownEnumMap(self):
    """Storing 123 in unknown_map_field raises ValueError."""
    enum_map = map_proto2_unittest_pb2.TestEnumMap()
    enum_map.known_map_field[123] = 0
    with self.assertRaises(ValueError):
      enum_map.unknown_map_field[1] = 123
  def testExtensionsErrors(self):
    """Reading Extensions on TestAllTypes raises AttributeError."""
    msg = unittest_pb2.TestAllTypes()
    with self.assertRaises(AttributeError):
      getattr(msg, 'Extensions')
  def testMergeFromExtensions(self):
    """MergeFrom updates an extension on a sub-message that was already read."""
    target = more_extensions_pb2.TopLevelMessage()
    source = more_extensions_pb2.TopLevelMessage()
    int_ext = more_extensions_pb2.optional_int_extension

    # Cpp extension will lazily create a sub message which is immutable.
    self.assertEqual(0, target.submessage.Extensions[int_ext])
    self.assertFalse(target.HasField('submessage'))
    source.submessage.Extensions[int_ext] = 123

    # Make sure cmessage and extensions pointing to a mutable message
    # after merge instead of the lazily created message.
    target.MergeFrom(source)
    self.assertEqual(123, target.submessage.Extensions[int_ext])
  def testGoldenExtensions(self):
    """Compares TestAllExtensions against the 'golden_message' data."""
    expected_bytes = test_util.GoldenFileData('golden_message')

    parsed = unittest_pb2.TestAllExtensions()
    parsed.ParseFromString(expected_bytes)

    populated = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(populated)

    self.assertEqual(populated, parsed)
    self.assertEqual(expected_bytes, parsed.SerializeToString())

    # A deep copy must serialize to the same bytes.
    duplicate = copy.deepcopy(parsed)
    self.assertEqual(expected_bytes, duplicate.SerializeToString())
  def testGoldenPackedExtensions(self):
    """Compares TestPackedExtensions against the packed golden data."""
    expected_bytes = test_util.GoldenFileData('golden_packed_fields_message')

    parsed = unittest_pb2.TestPackedExtensions()
    parsed.ParseFromString(expected_bytes)

    populated = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(populated)

    self.assertEqual(populated, parsed)
    self.assertEqual(expected_bytes, populated.SerializeToString())

    # A deep copy must serialize to the same bytes.
    duplicate = copy.deepcopy(parsed)
    self.assertEqual(expected_bytes, duplicate.SerializeToString())
  def testPickleIncompleteProto(self):
    """A TestRequired with only `a` set survives a pickle round trip."""
    original = unittest_pb2.TestRequired(a=1)
    restored = pickle.loads(pickle.dumps(original))

    self.assertEqual(restored, original)
    self.assertEqual(restored.a, 1)
    # This is still an incomplete proto - so serializing should fail
    with self.assertRaises(message.EncodeError):
      restored.SerializeToString()
  1260. # TODO(haberman): this isn't really a proto2-specific test except that this
  1261. # message has a required field in it. Should probably be factored out so
  1262. # that we can test the other parts with proto3.
  def testParsingMerge(self):
    """Check the merge behavior when a required or optional field appears
    multiple times in the input."""
    first = unittest_pb2.TestAllTypes()
    second = unittest_pb2.TestAllTypes()
    third = unittest_pb2.TestAllTypes()
    first.optional_int32 = 1
    second.optional_int64 = 2
    third.optional_int32 = 3
    third.optional_string = 'hello'
    parts = [first, second, third]

    # What the three parts are expected to look like once merged in order.
    expected = unittest_pb2.TestAllTypes()
    expected.optional_int32 = 3
    expected.optional_int64 = 2
    expected.optional_string = 'hello'

    # Emit every part three times over, once per repeated field/group of the
    # generator message.
    generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
    for repeated_field in (generator.field1, generator.field2,
                           generator.field3, generator.ext1, generator.ext2):
      repeated_field.extend(parts)
    for group_field in (generator.group1, generator.group2):
      for part in parts:
        group_field.add().field1.MergeFrom(part)

    parsing_merge = unittest_pb2.TestParsingMerge()
    parsing_merge.ParseFromString(generator.SerializeToString())

    # Required and optional fields should be merged.
    optional_ext = unittest_pb2.TestParsingMerge.optional_ext
    self.assertEqual(parsing_merge.required_all_types, expected)
    self.assertEqual(parsing_merge.optional_all_types, expected)
    self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
                     expected)
    self.assertEqual(parsing_merge.Extensions[optional_ext], expected)

    # Repeated fields should not be merged.
    repeated_ext = unittest_pb2.TestParsingMerge.repeated_ext
    self.assertEqual(len(parsing_merge.repeated_all_types), 3)
    self.assertEqual(len(parsing_merge.repeatedgroup), 3)
    self.assertEqual(len(parsing_merge.Extensions[repeated_ext]), 3)
  def testPythonicInit(self):
    """Constructor keyword arguments populate fields; bad input raises."""
    all_types = unittest_pb2.TestAllTypes
    msg = all_types(
        optional_int32=100,
        optional_fixed32=200,
        optional_float=300.5,
        optional_bytes=b'x',
        optionalgroup={'a': 400},
        optional_nested_message={'bb': 500},
        optional_foreign_message={},
        optional_nested_enum='BAZ',
        repeatedgroup=[{'a': 600},
                       {'a': 700}],
        repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
        default_int32=800,
        oneof_string='y')
    self.assertIsInstance(msg, all_types)

    # Scalar fields.
    self.assertEqual(100, msg.optional_int32)
    self.assertEqual(200, msg.optional_fixed32)
    self.assertEqual(300.5, msg.optional_float)
    self.assertEqual(b'x', msg.optional_bytes)

    # Groups and submessages initialised from dicts.
    self.assertEqual(400, msg.optionalgroup.a)
    self.assertIsInstance(msg.optional_nested_message,
                          all_types.NestedMessage)
    self.assertEqual(500, msg.optional_nested_message.bb)
    # An empty dict still marks the submessage as present.
    self.assertTrue(msg.HasField('optional_foreign_message'))
    self.assertEqual(msg.optional_foreign_message,
                     unittest_pb2.ForeignMessage())

    # Enums given by label, repeated groups and repeated enums.
    self.assertEqual(all_types.BAZ, msg.optional_nested_enum)
    self.assertEqual(2, len(msg.repeatedgroup))
    self.assertEqual(600, msg.repeatedgroup[0].a)
    self.assertEqual(700, msg.repeatedgroup[1].a)
    self.assertEqual(2, len(msg.repeated_nested_enum))
    self.assertEqual(all_types.FOO, msg.repeated_nested_enum[0])
    self.assertEqual(all_types.BAR, msg.repeated_nested_enum[1])
    self.assertEqual(800, msg.default_int32)
    self.assertEqual('y', msg.oneof_string)

    # Fields that were not passed keep their unset/default state.
    self.assertFalse(msg.HasField('optional_int64'))
    self.assertEqual(0, len(msg.repeated_float))
    self.assertEqual(42, msg.default_int64)

    # A unicode enum label is accepted as well.
    msg = all_types(optional_nested_enum=u'BAZ')
    self.assertEqual(all_types.BAZ, msg.optional_nested_enum)

    # Invalid constructor arguments and the error each one must raise.
    bad_inputs = (
        (ValueError,
         {'optional_nested_message': {'INVALID_NESTED_FIELD': 17}}),
        (TypeError,
         {'optional_nested_message': {'bb': 'INVALID_VALUE_TYPE'}}),
        (ValueError, {'optional_nested_enum': 'INVALID_LABEL'}),
        (ValueError, {'repeated_nested_enum': 'FOO'}),
    )
    for expected_error, kwargs in bad_inputs:
      with self.assertRaises(expected_error):
        unittest_pb2.TestAllTypes(**kwargs)
  def testPythonicInitWithDict(self):
    """Field names may be passed as str or unicode keys via **kwargs."""
    # Both string/unicode field name keys should work.
    field_values = {'optional_int32': 100}
    field_values[u'optional_fixed32'] = 200

    built = unittest_pb2.TestAllTypes(**field_values)
    self.assertEqual(100, built.optional_int32)
    self.assertEqual(200, built.optional_fixed32)
  def test_documentation(self):
    """pydoc output for TestAllTypes names the class, a method and a field."""
    # Also used by the interactive help() function.
    rendered = pydoc.html.document(unittest_pb2.TestAllTypes, 'message')
    for snippet in ('class TestAllTypes',
                    'SerializePartialToString',
                    'repeated_float'):
      self.assertIn(snippet, rendered)

    # The first base class must not expose `_extensions_by_name`.
    first_base = unittest_pb2.TestAllTypes.__bases__[0]
    with self.assertRaises(AttributeError):
      getattr(first_base, '_extensions_by_name')
  1378. # Class to test proto3-only features/behavior (updated field presence & enums)
  1379. @testing_refleaks.TestCase
  1380. class Proto3Test(unittest.TestCase):
  1381. # Utility method for comparing equality with a map.
  def assertMapIterEquals(self, map_iter, dict_value):
    """Assert that `map_iter` yields exactly the items of `dict_value`.

    Args:
      map_iter: iterable of (key, value) pairs.
      dict_value: mapping holding the expected pairs; it is not modified.

    Raises:
      KeyError: if `map_iter` yields a key absent from `dict_value` (or the
        same key twice).
    """
    # Work on a private copy so the caller's mapping stays untouched.
    remaining = dict(dict_value)
    for key, value in map_iter:
      self.assertEqual(value, remaining.pop(key))
    # Anything left over was expected but never produced by the iterator.
    self.assertEqual({}, remaining)
  def testFieldPresence(self):
    """HasField works only for submessage fields; ClearField resets values."""
    msg = unittest_proto3_arena_pb2.TestAllTypes()
    scalar_fields = ('optional_int32', 'optional_float',
                     'optional_string', 'optional_bool')

    def assert_all_defaults():
      # Fields should default to their type-specific default.
      self.assertEqual(0, msg.optional_int32)
      self.assertEqual(0, msg.optional_float)
      self.assertEqual('', msg.optional_string)
      self.assertEqual(False, msg.optional_bool)
      self.assertEqual(0, msg.optional_nested_message.bb)

    # We can't test presence of non-repeated, non-submessage fields.
    for field_name in scalar_fields:
      with self.assertRaises(ValueError):
        msg.HasField(field_name)

    # But we can still test presence of submessage fields.
    self.assertFalse(msg.HasField('optional_nested_message'))

    # As with proto2, we can't test presence of fields that don't exist, or
    # repeated fields.
    for field_name in ('field_doesnt_exist', 'repeated_int32',
                       'repeated_nested_message'):
      with self.assertRaises(ValueError):
        msg.HasField(field_name)

    assert_all_defaults()

    # Setting a submessage should still return proper presence information.
    msg.optional_nested_message.bb = 0
    self.assertTrue(msg.HasField('optional_nested_message'))

    # Set the fields to non-default values.
    msg.optional_int32 = 5
    msg.optional_float = 1.1
    msg.optional_string = 'abc'
    msg.optional_bool = True
    msg.optional_nested_message.bb = 15

    # Clearing the fields unsets them and resets their value to default.
    for field_name in scalar_fields + ('optional_nested_message',):
      msg.ClearField(field_name)
    assert_all_defaults()
  def testProto3ParserDropDefaultScalar(self):
    """Default scalars set explicitly in proto2 are not listed after a
    proto3 parse."""
    proto2_msg = unittest_pb2.TestAllTypes()
    proto2_msg.optional_int32 = 0
    proto2_msg.optional_string = ''
    proto2_msg.optional_bytes = b''
    # The proto2 message lists all three explicitly set fields.
    self.assertEqual(len(proto2_msg.ListFields()), 3)

    wire_bytes = proto2_msg.SerializeToString()
    proto3_msg = unittest_proto3_arena_pb2.TestAllTypes()
    proto3_msg.ParseFromString(wire_bytes)
    self.assertEqual(len(proto3_msg.ListFields()), 0)
  def testProto3Optional(self):
    """Fields of TestProto3Optional report presence through HasField."""
    msg = test_proto3_optional_pb2.TestProto3Optional()
    tracked_fields = ('optional_int32', 'optional_float',
                      'optional_string', 'optional_nested_message')

    def check_presence(target, expected):
      # Assert that every tracked field, plus the nested `bb`, is present
      # (expected=True) or absent (expected=False) on `target`.
      verify = self.assertTrue if expected else self.assertFalse
      for field_name in tracked_fields:
        verify(target.HasField(field_name))
      verify(target.optional_nested_message.HasField('bb'))

    check_presence(msg, False)

    # Set fields.
    msg.optional_int32 = 1
    msg.optional_float = 1.0
    msg.optional_string = '123'
    msg.optional_nested_message.bb = 1
    check_presence(msg, True)

    # Set to default value does not clear the fields
    msg.optional_int32 = 0
    msg.optional_float = 0.0
    msg.optional_string = ''
    msg.optional_nested_message.bb = 0
    check_presence(msg, True)

    # Test serialize
    reparsed = test_proto3_optional_pb2.TestProto3Optional()
    reparsed.ParseFromString(msg.SerializeToString())
    check_presence(reparsed, True)

    self.assertEqual(msg.WhichOneof('_optional_int32'), 'optional_int32')

    # Clear these fields.
    for field_name in tracked_fields:
      msg.ClearField(field_name)
    check_presence(msg, False)
    self.assertEqual(msg.WhichOneof('_optional_int32'), None)
  def testAssignUnknownEnum(self):
    """Assigning an unknown enum value is allowed and preserves the value."""
    singular_value = 1234567
    appended_value = 22334455
    replaced_value = 7654321

    original = unittest_proto3_arena_pb2.TestAllTypes()
    # Proto3 can assign unknown enums.
    original.optional_nested_enum = singular_value
    self.assertEqual(singular_value, original.optional_nested_enum)
    original.repeated_nested_enum.append(appended_value)
    self.assertEqual(appended_value, original.repeated_nested_enum[0])
    # Assignment is a different code path than append for the C++ impl.
    original.repeated_nested_enum[0] = replaced_value
    self.assertEqual(replaced_value, original.repeated_nested_enum[0])

    # The unknown values must also survive serialization and parsing.
    reparsed = unittest_proto3_arena_pb2.TestAllTypes()
    reparsed.ParseFromString(original.SerializeToString())
    self.assertEqual(singular_value, reparsed.optional_nested_enum)
    self.assertEqual(replaced_value, reparsed.repeated_nested_enum[0])
  1508. # Map isn't really a proto3-only feature. But there is no proto2 equivalent
  1509. # of google/protobuf/map_unittest.proto right now, so it's not easy to
  1510. # test both with the same test like we do for the other proto2/proto3 tests.
  1511. # (google/protobuf/map_proto2_unittest.proto is very different in the set
  1512. # of messages and fields it contains).
  def testScalarMapDefaults(self):
    """Reading a missing key of a scalar map returns and stores the default."""
    msg = map_unittest_pb2.TestMap()
    # (map field, probe key, expected default, expected value type or None)
    cases = (
        (msg.map_int32_int32, -123, 0, None),
        (msg.map_int64_int64, -2**33, 0, None),
        (msg.map_uint32_uint32, 123, 0, None),
        (msg.map_uint64_uint64, 2**33, 0, None),
        (msg.map_int32_double, 123, 0.0, float),
        (msg.map_bool_bool, False, False, bool),
        (msg.map_string_string, 'abc', '', None),
        (msg.map_int32_bytes, 111, b'', None),
        (msg.map_int32_enum, 888, 0, None),
    )

    # Scalars start out unset.
    for container, key, _, _ in cases:
      self.assertNotIn(key, container)

    # Accessing an unset key returns the default.
    for container, key, default, value_type in cases:
      self.assertEqual(default, container[key])
      if value_type is not None:
        self.assertIsInstance(container[key], value_type)

    # It also sets the value in the map
    for container, key, _, _ in cases:
      self.assertIn(key, container)

    self.assertIsInstance(msg.map_string_string['abc'], six.text_type)

    # Accessing an unset key still throws TypeError if the type of the key
    # is incorrect.
    with self.assertRaises(TypeError):
      msg.map_string_string[123]
    with self.assertRaises(TypeError):
      123 in msg.map_string_string
  def testMapGet(self):
    """Map `get()` returns the given default without inserting the key."""
    # Need to test that get() properly returns the default, even though the
    # dict has defaultdict-like semantics.
    msg = map_unittest_pb2.TestMap()
    scalar_map = msg.map_int32_int32
    message_map = msg.map_int32_foreign_message

    self.assertIsNone(scalar_map.get(5))
    self.assertEqual(10, scalar_map.get(5, 10))
    self.assertEqual(10, scalar_map.get(key=5, default=10))
    # The earlier get() calls must not have created the entry.
    self.assertIsNone(scalar_map.get(5))

    scalar_map[5] = 15
    # Looked up twice on purpose, mirroring the two consecutive checks.
    for _ in range(2):
      self.assertEqual(15, scalar_map.get(5))
    with self.assertRaises(TypeError):
      scalar_map.get('')

    self.assertIsNone(message_map.get(5))
    self.assertEqual(10, message_map.get(5, 10))
    self.assertEqual(10, message_map.get(key=5, default=10))

    # Once the entry exists, get() hands back the very same submessage.
    created = message_map[5]
    self.assertIs(created, message_map.get(5))
    with self.assertRaises(TypeError):
      message_map.get('')
  def testScalarMap(self):
    """Scalar maps reject bad key/value types and survive a round trip."""
    msg = map_unittest_pb2.TestMap()

    self.assertEqual(0, len(msg.map_int32_int32))
    self.assertNotIn(5, msg.map_int32_int32)

    msg.map_int32_int32[-123] = -456
    msg.map_int64_int64[-2**33] = -2**34
    msg.map_uint32_uint32[123] = 456
    msg.map_uint64_uint64[2**33] = 2**34
    msg.map_int32_float[2] = 1.2
    msg.map_int32_double[1] = 3.3
    msg.map_string_string['abc'] = '123'
    msg.map_bool_bool[True] = True
    msg.map_int32_enum[888] = 2
    # Unknown numeric enum is supported in proto3.
    msg.map_int32_enum[123] = 456

    self.assertEqual([], msg.FindInitializationErrors())

    self.assertEqual(1, len(msg.map_string_string))

    # Bad key.
    with self.assertRaises(TypeError):
      msg.map_string_string[123] = '123'

    # Verify that trying to assign a bad key doesn't actually add a member to
    # the map.
    self.assertEqual(1, len(msg.map_string_string))

    # Bad value.
    with self.assertRaises(TypeError):
      msg.map_string_string['123'] = 123

    serialized = msg.SerializeToString()
    msg2 = map_unittest_pb2.TestMap()
    msg2.ParseFromString(serialized)

    # Bad key.
    with self.assertRaises(TypeError):
      msg2.map_string_string[123] = '123'

    # Bad value.
    with self.assertRaises(TypeError):
      msg2.map_string_string['123'] = 123

    self.assertEqual(-456, msg2.map_int32_int32[-123])
    self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
    self.assertEqual(456, msg2.map_uint32_uint32[123])
    self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
    # The float/double entries used to be checked on `msg` only, so the
    # parsed copy was never verified for them. Check both messages.
    for candidate in (msg, msg2):
      self.assertAlmostEqual(1.2, candidate.map_int32_float[2])
      self.assertEqual(3.3, candidate.map_int32_double[1])
    self.assertEqual('123', msg2.map_string_string['abc'])
    self.assertEqual(True, msg2.map_bool_bool[True])
    self.assertEqual(2, msg2.map_int32_enum[888])
    self.assertEqual(456, msg2.map_int32_enum[123])
    self.assertEqual('{-123: -456}',
                     str(msg2.map_int32_int32))
  def testMapEntryAlwaysSerialized(self):
    """Map entries with default key and value are still serialized."""
    expected_bytes = b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00'

    msg = map_unittest_pb2.TestMap()
    msg.map_int32_int32[0] = 0
    msg.map_string_string[''] = ''

    self.assertEqual(msg.ByteSize(), 12)
    self.assertEqual(expected_bytes, msg.SerializeToString())
  def testStringUnicodeConversionInMap(self):
    """UTF-8 bytes stored in a string map are read back as text."""
    text = u'\u1234'
    encoded = text.encode('utf8')

    msg = map_unittest_pb2.TestMap()
    msg.map_string_string[encoded] = encoded

    stored_items = list(msg.map_string_string.items())
    stored_key, stored_value = stored_items[0]

    self.assertEqual(stored_key, text)
    self.assertEqual(stored_value, text)
    self.assertIsInstance(stored_key, six.text_type)
    self.assertIsInstance(stored_value, six.text_type)
  def testMessageMap(self):
    """Message-valued map entries are created on access, never assigned."""
    msg = map_unittest_pb2.TestMap()

    self.assertEqual(0, len(msg.map_int32_foreign_message))
    self.assertNotIn(5, msg.map_int32_foreign_message)

    # Plain item access creates the entry.
    msg.map_int32_foreign_message[123]
    # get_or_create() is an alias for getitem.
    msg.map_int32_foreign_message.get_or_create(-456)

    self.assertEqual(2, len(msg.map_int32_foreign_message))
    self.assertIn(123, msg.map_int32_foreign_message)
    self.assertIn(-456, msg.map_int32_foreign_message)
    # The membership tests above must not have added entries.
    self.assertEqual(2, len(msg.map_int32_foreign_message))

    # Bad key.
    with self.assertRaises(TypeError):
      msg.map_int32_foreign_message['123']

    # Can't assign directly to submessage.
    with self.assertRaises(ValueError):
      msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123]

    # Verify that trying to assign a bad key doesn't actually add a member to
    # the map.
    self.assertEqual(2, len(msg.map_int32_foreign_message))

    reparsed = map_unittest_pb2.TestMap()
    reparsed.ParseFromString(msg.SerializeToString())

    self.assertEqual(2, len(reparsed.map_int32_foreign_message))
    self.assertIn(123, reparsed.map_int32_foreign_message)
    self.assertIn(-456, reparsed.map_int32_foreign_message)
    self.assertEqual(2, len(reparsed.map_int32_foreign_message))

    reparsed.map_int32_foreign_message[123].c = 1
    # TODO(jieluo): Fix text format for message map.
    # Either iteration order of the two keys is accepted.
    accepted_renderings = ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }')
    self.assertIn(str(reparsed.map_int32_foreign_message),
                  accepted_renderings)
  def testNestedMessageMapItemDelete(self):
    """Message map entries can be deleted, re-created and deleted again."""
    msg = map_unittest_pb2.TestMap()

    # Create entry 1, drop it, then create entry 2: only one entry remains.
    msg.map_int32_all_types[1].optional_nested_message.bb = 1
    del msg.map_int32_all_types[1]
    msg.map_int32_all_types[2].optional_nested_message.bb = 2
    self.assertEqual(1, len(msg.map_int32_all_types))

    # Re-creating the deleted key brings the count back to two.
    msg.map_int32_all_types[1].optional_nested_message.bb = 1
    self.assertEqual(2, len(msg.map_int32_all_types))

    reparsed = map_unittest_pb2.TestMap()
    reparsed.ParseFromString(msg.SerializeToString())

    # The loop triggers PyErr_Occurred() in c extension.
    for doomed_key in [1, 2]:
      del reparsed.map_int32_all_types[doomed_key]
  def testMapByteSize(self):
    """ByteSize grows by one when a map value changes from 1 to 128."""
    msg = map_unittest_pb2.TestMap()

    # Scalar map value.
    msg.map_int32_int32[1] = 1
    size_before = msg.ByteSize()
    msg.map_int32_int32[1] = 128
    self.assertEqual(msg.ByteSize(), size_before + 1)

    # Field inside a message map value.
    msg.map_int32_foreign_message[19].c = 1
    size_before = msg.ByteSize()
    msg.map_int32_foreign_message[19].c = 128
    self.assertEqual(msg.ByteSize(), size_before + 1)
  def testMergeFrom(self):
    """Message-level MergeFrom replaces map entries that share a key."""
    source = map_unittest_pb2.TestMap()
    source.map_int32_int32[12] = 34
    source.map_int32_int32[56] = 78
    source.map_int64_int64[22] = 33
    source.map_int32_foreign_message[111].c = 5
    source.map_int32_foreign_message[222].c = 10

    target = map_unittest_pb2.TestMap()
    target.map_int32_int32[12] = 55
    target.map_int64_int64[88] = 99
    target.map_int32_foreign_message[222].c = 15
    target.map_int32_foreign_message[222].d = 20
    # Keep a handle on the value that the merge is about to replace.
    stale_value = target.map_int32_foreign_message[222]

    target.MergeFrom(source)
    # Compare with expected message instead of call
    # target.map_int32_foreign_message[222] to make sure MergeFrom does not
    # sync with repeated field and there is no duplicated keys.
    expected = map_unittest_pb2.TestMap()
    expected.CopyFrom(source)
    expected.map_int64_int64[88] = 99
    self.assertEqual(target, expected)

    self.assertEqual(34, target.map_int32_int32[12])
    self.assertEqual(78, target.map_int32_int32[56])
    self.assertEqual(33, target.map_int64_int64[22])
    self.assertEqual(99, target.map_int64_int64[88])
    self.assertEqual(5, target.map_int32_foreign_message[111].c)
    self.assertEqual(10, target.map_int32_foreign_message[222].c)
    self.assertFalse(target.map_int32_foreign_message[222].HasField('d'))
    if api_implementation.Type() != 'cpp':
      # During the call to MergeFrom(), the C++ implementation will have
      # deallocated the underlying message, but this is very difficult to
      # detect properly. The line below is likely to cause a segmentation
      # fault. With the Python implementation, stale_value is just
      # 'detached' from the main message. Using it will not crash of course,
      # but since it still has a reference to the parent message I'm sure we
      # can find interesting ways to cause inconsistencies.
      self.assertEqual(15, stale_value.c)

    # Verify that there is only one entry per key, even though the MergeFrom
    # may have internally created multiple entries for a single key in the
    # list representation.
    seen = {}
    for key in target.map_int32_foreign_message:
      self.assertNotIn(key, seen)
      seen[key] = target.map_int32_foreign_message[key].c
    self.assertEqual({111: 5, 222: 10}, seen)

    # Special case: test that delete of item really removes the item, even if
    # there might have physically been duplicate keys due to the previous
    # merge. This is only a special case for the C++ implementation which
    # stores the map as an array.
    del target.map_int32_int32[12]
    self.assertNotIn(12, target.map_int32_int32)

    del target.map_int32_foreign_message[222]
    self.assertNotIn(222, target.map_int32_foreign_message)
    with self.assertRaises(TypeError):
      del target.map_int32_foreign_message['']
  def testMapMergeFrom(self):
    """Exercise MergeFrom on map containers and on messages holding maps."""
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_int32[12] = 34
    msg.map_int32_int32[56] = 78
    msg.map_int64_int64[22] = 33
    msg.map_int32_foreign_message[111].c = 5
    msg.map_int32_foreign_message[222].c = 10

    msg2 = map_unittest_pb2.TestMap()
    msg2.map_int32_int32[12] = 55
    msg2.map_int64_int64[88] = 99
    msg2.map_int32_foreign_message[222].c = 15
    msg2.map_int32_foreign_message[222].d = 20

    # Merge each map container individually.
    msg2.map_int32_int32.MergeFrom(msg.map_int32_int32)
    self.assertEqual(34, msg2.map_int32_int32[12])
    self.assertEqual(78, msg2.map_int32_int32[56])

    msg2.map_int64_int64.MergeFrom(msg.map_int64_int64)
    self.assertEqual(33, msg2.map_int64_int64[22])
    self.assertEqual(99, msg2.map_int64_int64[88])

    msg2.map_int32_foreign_message.MergeFrom(msg.map_int32_foreign_message)
    # Compare with expected message instead of call
    # msg.map_int32_foreign_message[222] to make sure MergeFrom does not
    # sync with repeated field and no duplicated keys.
    expected_msg = map_unittest_pb2.TestMap()
    expected_msg.CopyFrom(msg)
    expected_msg.map_int64_int64[88] = 99
    self.assertEqual(msg2, expected_msg)

    # Test when cpp extension cache a map.
    m1 = map_unittest_pb2.TestMap()
    m2 = map_unittest_pb2.TestMap()
    self.assertEqual(m1.map_int32_foreign_message,
                     m1.map_int32_foreign_message)
    m2.map_int32_foreign_message[123].c = 10
    m1.MergeFrom(m2)
    # The source must be left untouched by the merge ...
    self.assertEqual(10, m2.map_int32_foreign_message[123].c)
    # ... and the target, whose map was read before the merge, must expose
    # the merged entry. Previously only the source was checked, so the merge
    # result itself was never verified.
    self.assertEqual(10, m1.map_int32_foreign_message[123].c)

    # Test merge maps within different message types.
    m1 = map_unittest_pb2.TestMap()
    m2 = map_unittest_pb2.TestMessageMap()
    m2.map_int32_message[123].optional_int32 = 10
    m1.map_int32_all_types.MergeFrom(m2.map_int32_message)
    self.assertEqual(10, m1.map_int32_all_types[123].optional_int32)

    # Test overwrite message value map
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_foreign_message[222].c = 123
    msg2 = map_unittest_pb2.TestMap()
    msg2.map_int32_foreign_message[222].d = 20
    msg.MergeFromString(msg2.SerializeToString())
    self.assertEqual(msg.map_int32_foreign_message[222].d, 20)
    self.assertNotEqual(msg.map_int32_foreign_message[222].c, 123)

    # Merge a dict to map field is not accepted
    with self.assertRaises(AttributeError):
      m1.map_int32_all_types.MergeFrom(
          {1: unittest_proto3_arena_pb2.TestAllTypes()})
  def testMergeFromBadType(self):
    """MergeFrom with a non-message argument raises a descriptive TypeError."""
    msg = map_unittest_pb2.TestMap()
    # six.assertRaisesRegex picks the method name that exists on the running
    # interpreter; the `assertRaisesRegexp` spelling is gone in Python 3.12.
    with six.assertRaisesRegex(
        self,
        TypeError,
        r'Parameter to MergeFrom\(\) must be instance of same class: expected '
        r'.*TestMap got int\.'):
      msg.MergeFrom(1)
  def testCopyFromBadType(self):
    """CopyFrom with a non-message argument raises a descriptive TypeError."""
    msg = map_unittest_pb2.TestMap()
    # six.assertRaisesRegex picks the method name that exists on the running
    # interpreter; the `assertRaisesRegexp` spelling is gone in Python 3.12.
    # The pattern accepts any `...From()` name in the error message.
    with six.assertRaisesRegex(
        self,
        TypeError,
        r'Parameter to [A-Za-z]*From\(\) must be instance of same class: '
        r'expected .*TestMap got int\.'):
      msg.CopyFrom(1)
  def testIntegerMapWithLongs(self):
    """Integer maps accept `long` keys and values and round-trip them."""
    # (map field name, key, value) written as longs and read back as ints.
    entries = (
        ('map_int32_int32', -123, -456),
        ('map_int64_int64', -2**33, -2**34),
        ('map_uint32_uint32', 123, 456),
        ('map_uint64_uint64', 2**33, 2**34),
    )

    msg = map_unittest_pb2.TestMap()
    for field_name, key, value in entries:
      getattr(msg, field_name)[long(key)] = long(value)

    reparsed = map_unittest_pb2.TestMap()
    reparsed.ParseFromString(msg.SerializeToString())

    for field_name, key, value in entries:
      self.assertEqual(value, getattr(reparsed, field_name)[key])
  def testMapAssignmentCausesPresence(self):
    """Scalar map edits inside a submessage show up in the serialized parent."""
    msg = map_unittest_pb2.TestMapSubmessage()
    reparsed = map_unittest_pb2.TestMapSubmessage()

    def assert_round_trips():
      # Serialize `msg`, parse the bytes into `reparsed`, compare the two.
      reparsed.ParseFromString(msg.SerializeToString())
      self.assertEqual(msg, reparsed)

    msg.test_map.map_int32_int32[123] = 456
    assert_round_trips()

    # Now test that various mutations of the map properly invalidate the
    # cached size of the submessage.
    msg.test_map.map_int32_int32[888] = 999
    assert_round_trips()

    msg.test_map.map_int32_int32.clear()
    assert_round_trips()
  def testMapAssignmentCausesPresenceForSubmessages(self):
    """Message map edits inside a submessage show up in the serialized
    parent."""
    msg = map_unittest_pb2.TestMapSubmessage()
    reparsed = map_unittest_pb2.TestMapSubmessage()

    def assert_round_trips():
      # Serialize `msg`, parse the bytes into `reparsed`, compare the two.
      reparsed.ParseFromString(msg.SerializeToString())
      self.assertEqual(msg, reparsed)

    msg.test_map.map_int32_foreign_message[123].c = 5
    assert_round_trips()

    # Now test that various mutations of the map properly invalidate the
    # cached size of the submessage.
    msg.test_map.map_int32_foreign_message[888].c = 7
    assert_round_trips()

    msg.test_map.map_int32_foreign_message[888].MergeFrom(
        msg.test_map.map_int32_foreign_message[123])
    assert_round_trips()

    msg.test_map.map_int32_foreign_message.clear()
    assert_round_trips()
  def testModifyMapWhileIterating(self):
    """Iterators created before a map is modified raise RuntimeError."""
    msg = map_unittest_pb2.TestMap()

    # Obtain both iterators first, then mutate the maps underneath them.
    stale_iterators = (iter(msg.map_string_string),
                       iter(msg.map_int32_foreign_message))
    msg.map_string_string['abc'] = '123'
    msg.map_int32_foreign_message[5].c = 5

    for stale_iterator in stale_iterators:
      with self.assertRaises(RuntimeError):
        for _ in stale_iterator:
          pass
  def testSubmessageMap(self):
    """Message map values are stable objects and cannot be assigned."""
    msg = map_unittest_pb2.TestMap()

    entry = msg.map_int32_foreign_message[111]
    # A second lookup returns the same object, of the expected type.
    self.assertIs(entry, msg.map_int32_foreign_message[111])
    self.assertIsInstance(entry, unittest_pb2.ForeignMessage)

    entry.c = 5

    reparsed = map_unittest_pb2.TestMap()
    reparsed.ParseFromString(msg.SerializeToString())
    self.assertEqual(5, reparsed.map_int32_foreign_message[111].c)

    # Doesn't allow direct submessage assignment.
    with self.assertRaises(ValueError):
      msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
  def testMapIteration(self):
    """Checks items() on an empty scalar map and on one with three entries."""
    msg = map_unittest_pb2.TestMap()

    # A freshly constructed map must not yield anything. self.fail() states
    # the intent directly and reports the offending item.
    for unexpected in msg.map_int32_int32.items():
      self.fail('Empty map unexpectedly yielded %r' % (unexpected,))

    msg.map_int32_int32[2] = 4
    msg.map_int32_int32[3] = 6
    msg.map_int32_int32[4] = 8
    self.assertEqual(3, len(msg.map_int32_int32))

    matching_dict = {2: 4, 3: 6, 4: 8}
    self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
  def testPython2Map(self):
    """Exercises the Python 2 only mapping API of a scalar map field.

    On Python 3 the body is not executed and the test passes trivially.
    """
    if sys.version_info >= (3,):
      return

    source = map_unittest_pb2.TestMap()
    for key in (2, 3, 4, 5):
      source.map_int32_int32[key] = 2 * key
    map_int32 = source.map_int32_int32
    self.assertEqual(4, len(map_int32))

    # Independent copy used further down as the argument to update().
    snapshot = map_unittest_pb2.TestMap()
    snapshot.ParseFromString(source.SerializeToString())

    def check_lazy_matches_eager(eager, lazy):
      # The iterator must produce the same elements, in the same order, as
      # the list returned by the eager accessor.
      head, tail = eager[0], eager[1:]
      self.assertEqual(next(lazy), head)
      self.assertEqual(list(lazy), tail)

    check_lazy_matches_eager(map_int32.items(), map_int32.iteritems())
    check_lazy_matches_eager(map_int32.keys(), map_int32.iterkeys())
    check_lazy_matches_eager(map_int32.values(), map_int32.itervalues())

    self.assertEqual(6, map_int32.get(3))
    self.assertEqual(None, map_int32.get(999))

    # The first pop() returns the stored value, the second one returns 0.
    self.assertEqual(6, map_int32.pop(3))
    self.assertEqual(0, map_int32.pop(3))
    self.assertEqual(3, len(map_int32))

    popped_key, popped_value = map_int32.popitem()
    self.assertEqual(2 * popped_key, popped_value)
    self.assertEqual(2, len(map_int32))

    map_int32.clear()
    self.assertEqual(0, len(map_int32))
    with self.assertRaises(KeyError):
      map_int32.popitem()

    self.assertEqual(0, map_int32.setdefault(2))
    self.assertEqual(1, len(map_int32))

    map_int32.update(snapshot.map_int32_int32)
    self.assertEqual(4, len(map_int32))

    # update() rejects two positional arguments, a non-mapping argument and
    # keyword arguments.
    with self.assertRaises(TypeError):
      map_int32.update(snapshot.map_int32_int32,
                       snapshot.map_int32_int32)
    with self.assertRaises(TypeError):
      map_int32.update(0)
    with self.assertRaises(TypeError):
      map_int32.update(value=12)
  def testMapItems(self):
    """Two consecutive items() calls on the same map must compare equal."""
    # Map items used to behave strangely with the C extension, because []
    # may reorder the map and invalidate any existing iterators.
    # TODO(jieluo): Check if [] reordering the map is a bug or intended
    # behavior.
    msg = map_unittest_pb2.TestMap()
    for name in ('local_init_op', 'trainable_variables', 'variables',
                 'init_op', 'summaries'):
      msg.map_string_string[name] = ''

    first_view = msg.map_string_string.items()
    second_view = msg.map_string_string.items()
    self.assertEqual(first_view, second_view)
  def testMapDeterministicSerialization(self):
    """Deterministic serialization of a string map must match golden bytes."""
    # Entries are inserted in this (non-sorted) order on purpose.
    entries = (
        ('local_init_op', 'a'),
        ('trainable_variables', 'b'),
        ('variables', 'c'),
        ('init_op', 'd'),
        ('summaries', 'e'),
        ('item1', 'e'),
        ('item2', 'f'),
        ('item3', 'g'),
        ('item4', 'QQ'),
    )
    msg = map_unittest_pb2.TestMap()
    for key, value in entries:
      msg.map_string_string[key] = value

    golden_data = (b'r\x0c\n\x07init_op\x12\x01d'
                   b'r\n\n\x05item1\x12\x01e'
                   b'r\n\n\x05item2\x12\x01f'
                   b'r\n\n\x05item3\x12\x01g'
                   b'r\x0b\n\x05item4\x12\x02QQ'
                   b'r\x12\n\rlocal_init_op\x12\x01a'
                   b'r\x0e\n\tsummaries\x12\x01e'
                   b'r\x18\n\x13trainable_variables\x12\x01b'
                   b'r\x0e\n\tvariables\x12\x01c')

    # If deterministic serialization is not working correctly, this will be
    # "flaky" depending on the exact python dict hash seed.
    #
    # Fortunately, there are enough items in this map that it is extremely
    # unlikely to ever hit the "right" in-order combination, so the test
    # itself should fail reliably.
    actual = msg.SerializeToString(deterministic=True)
    self.assertEqual(golden_data, actual)
  def testMapIterationClearMessage(self):
    """items() must stay usable after the owning message name is deleted."""
    pairs = ((2, 4), (3, 6), (4, 8))
    msg = map_unittest_pb2.TestMap()
    for key, value in pairs:
      msg.map_int32_int32[key] = value

    # The iterator needs to work even if message and map are deleted.
    it = msg.map_int32_int32.items()
    del msg

    self.assertMapIterEquals(it, dict(pairs))
  def testMapConstruction(self):
    """Map fields can be populated through constructor keyword arguments."""
    scalar_msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
    for key, expected in ((1, 2), (3, 4)):
      self.assertEqual(expected, scalar_msg.map_int32_int32[key])

    message_msg = map_unittest_pb2.TestMap(
        map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
    self.assertEqual(5, message_msg.map_int32_foreign_message[3].c)
  def testMapScalarFieldConstruction(self):
    """A scalar map field of another message is accepted by the constructor."""
    donor = map_unittest_pb2.TestMap()
    donor.map_int32_int32[1] = 42

    built = map_unittest_pb2.TestMap(map_int32_int32=donor.map_int32_int32)
    self.assertEqual(42, built.map_int32_int32[1])
  def testMapMessageFieldConstruction(self):
    """A message map field of another message is accepted by the constructor."""
    donor = map_unittest_pb2.TestMap()
    donor.map_string_foreign_message['test'].c = 42

    built = map_unittest_pb2.TestMap(
        map_string_foreign_message=donor.map_string_foreign_message)
    self.assertEqual(42, built.map_string_foreign_message['test'].c)
  def testMapFieldRaisesCorrectError(self):
    """Passing a non-iterable for a map field must raise TypeError."""
    not_iterable = 1
    with self.assertRaises(TypeError):
      map_unittest_pb2.TestMap(map_string_foreign_message=not_iterable)
  def testMapValidAfterFieldCleared(self):
    """A held scalar map keeps its entries after ClearField on its owner."""
    # The map needs to work even if the field is cleared.
    # For the C++ implementation this tests the correctness of
    # MapContainer::Release()
    pairs = ((2, 4), (3, 6), (4, 8))
    msg = map_unittest_pb2.TestMap()
    int32_map = msg.map_int32_int32
    for key, value in pairs:
      int32_map[key] = value

    msg.ClearField('map_int32_int32')
    self.assertEqual(b'', msg.SerializeToString())

    # The previously obtained container still exposes the old entries.
    self.assertMapIterEquals(int32_map.items(), dict(pairs))
  def testMessageMapValidAfterFieldCleared(self):
    """A held message map keeps its keys after ClearField on its owner."""
    # The map needs to work even if the field is cleared.
    # For the C++ implementation this tests the correctness of
    # MapContainer::Release()
    msg = map_unittest_pb2.TestMap()
    int32_foreign_message = msg.map_int32_foreign_message
    int32_foreign_message[2].c = 5

    msg.ClearField('map_int32_foreign_message')
    self.assertEqual(b'', msg.SerializeToString())

    # assertIn reports both the key and the container on failure, unlike
    # assertTrue(2 in ...), which only reports "False is not true".
    self.assertIn(2, int32_foreign_message.keys())
  def testMessageMapItemValidAfterTopMessageCleared(self):
    """A held message map value keeps its data after Clear() on the owner."""
    # A message map item needs to work even if it is cleared.
    # For the C++ implementation this tests the correctness of
    # MapContainer::Release()
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[2].optional_string = 'bar'

    running_cpp = api_implementation.Type() == 'cpp'
    if running_cpp:
      # Need to keep the map reference because of b/27942626.
      # TODO(jieluo): Remove it.
      unused_map = msg.map_int32_all_types  # pylint: disable=unused-variable

    held_value = msg.map_int32_all_types[2]
    msg.Clear()

    # Reset to trigger sync between repeated field and map in c++.
    msg.map_int32_all_types[3].optional_string = 'foo'
    self.assertEqual(held_value.optional_string, 'bar')
  def testMapIterInvalidatedByClearField(self):
    """Iterating a map iterator after ClearField raises RuntimeError."""
    # A map iterator is invalidated when the field is cleared, but this
    # case must not crash the interpreter.
    # For the C++ implementation this tests the correctness of
    # ScalarMapContainer::Release()
    msg = map_unittest_pb2.TestMap()
    for field_name in ('map_int32_int32', 'map_int32_foreign_message'):
      stale = iter(getattr(msg, field_name))
      msg.ClearField(field_name)
      with self.assertRaises(RuntimeError):
        for _ in stale:
          pass
  def testMapDelete(self):
    """del on map entries removes present keys and raises for absent ones."""
    msg = map_unittest_pb2.TestMap()
    scalar_map = msg.map_int32_int32

    self.assertEqual(0, len(scalar_map))
    scalar_map[4] = 6
    self.assertEqual(1, len(scalar_map))

    # Deleting a key that was never set raises KeyError.
    with self.assertRaises(KeyError):
      del scalar_map[88]

    del scalar_map[4]
    self.assertEqual(0, len(scalar_map))

    # The same holds for a message-valued map.
    with self.assertRaises(KeyError):
      del msg.map_int32_all_types[32]
  def testMapsAreMapping(self):
    """Scalar and message maps are instances of the mapping ABCs."""
    msg = map_unittest_pb2.TestMap()
    expected_bases = (collections_abc.Mapping, collections_abc.MutableMapping)
    for container in (msg.map_int32_int32, msg.map_int32_foreign_message):
      for base in expected_bases:
        self.assertIsInstance(container, base)
  def testMapsCompare(self):
    """Maps compare equal to themselves and unequal to a plain integer."""
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_int32[-123] = -456

    scalar_map = msg.map_int32_int32
    self.assertEqual(scalar_map, msg.map_int32_int32)
    self.assertEqual(msg.map_int32_foreign_message,
                     msg.map_int32_foreign_message)
    self.assertNotEqual(msg.map_int32_int32, 0)
  def testMapFindInitializationErrorsSmokeTest(self):
    """FindInitializationErrors() reports nothing for populated map fields."""
    msg = map_unittest_pb2.TestMap()
    msg.map_string_string['abc'] = '123'
    msg.map_int32_int32[35] = 64
    msg.map_string_foreign_message['foo'].c = 5

    errors = msg.FindInitializationErrors()
    self.assertEqual(0, len(errors))
  @unittest.skipIf(sys.maxunicode == UCS2_MAXUNICODE, 'Skip for ucs2')
  def testStrictUtf8Check(self):
    """Parsing rejects an encoded lone surrogate but accepts valid text."""
    # Test u'\ud801' is rejected at parser in both python2 and python3.
    lone_surrogate_payload = (b'r\x03\xed\xa0\x81')
    target = unittest_proto3_arena_pb2.TestAllTypes()
    with self.assertRaises(Exception) as context:
      target.MergeFromString(lone_surrogate_payload)

    # The error text differs between the two implementations.
    expected_fragment = ('optional_string'
                         if api_implementation.Type() == 'python'
                         else 'Error parsing message')
    self.assertIn(expected_fragment, str(context.exception))

    # Test optional_string=u'😍' is accepted.
    emoji_bytes = unittest_proto3_arena_pb2.TestAllTypes(
        optional_string=u'😍').SerializeToString()
    decoded = unittest_proto3_arena_pb2.TestAllTypes()
    decoded.MergeFromString(emoji_bytes)
    self.assertEqual(decoded.optional_string, u'😍')

    # A non-surrogate code point set through the constructor reads back
    # unchanged.
    plain = unittest_proto3_arena_pb2.TestAllTypes(
        optional_string=u'\ud001')
    self.assertEqual(plain.optional_string, u'\ud001')
  @unittest.skipIf(six.PY2, 'Surrogates are acceptable in python2')
  def testSurrogatesInPython3(self):
    """Setters reject surrogate code points on Python 3."""
    # Surrogates like U+D83D are invalid unicode characters; they are
    # supported by Python2 only because in some builds, unicode strings
    # use 2-byte code units. Since Python 3.3, we don't have this problem.
    #
    # Surrogates are utf16 code units; in a unicode string they are invalid
    # characters even when they appear in pairs like u'\ud801\udc01'. Protobuf
    # Python3 rejects such cases at setters and parsers. Python2 accepts them
    # to keep the same features as the language itself. 'Unpaired pairs'
    # like u'\ud801' are rejected at parsers when strict utf8 check is enabled
    # in proto3 to keep the same behavior as the c extension.

    # Surrogates are rejected at setters in Python3.
    rejected_values = (
        u'\ud801\udc01',
        b'\xed\xa0\x81',
        u'\ud801',
        u'\ud801\ud801',
    )
    for bad_value in rejected_values:
      with self.assertRaises(ValueError):
        unittest_proto3_arena_pb2.TestAllTypes(
            optional_string=bad_value)
  @unittest.skipIf(six.PY3 or sys.maxunicode == UCS2_MAXUNICODE,
                   'Surrogates are rejected at setters in Python3')
  def testSurrogatesInPython2(self):
    """Setters accept surrogate code points on wide-unicode Python 2."""
    # Test optional_string=u'\ud801\udc01'.
    # A surrogate pair is acceptable in python2.
    paired = unittest_proto3_arena_pb2.TestAllTypes(
        optional_string=u'\ud801\udc01')

    # TODO(jieluo): Change pure python to have same behavior with c extension.
    # Some builds of python2 consider u'\ud801\udc01' and u'\U00010401'
    # equal, some do not.
    expected_text = (u'\ud801\udc01'
                     if api_implementation.Type() == 'python'
                     else u'\U00010401')
    self.assertEqual(paired.optional_string, expected_text)

    # After a serialize/parse cycle the value reads back as one code point.
    reparsed = unittest_proto3_arena_pb2.TestAllTypes()
    reparsed.MergeFromString(paired.SerializeToString())
    self.assertEqual(reparsed.optional_string, u'\U00010401')

    # Python2 does not reject surrogates at setters.
    for accepted_value in (b'\xed\xa0\x81', u'\ud801', u'\ud801\ud801'):
      unittest_proto3_arena_pb2.TestAllTypes(
          optional_string=accepted_value)
  @testing_refleaks.TestCase
  class ValidTypeNamesTest(unittest.TestCase):
    """Checks that repeated-field container type names are importable."""

    def assertImportFromName(self, msg, base_name):
      """Asserts that the type of `msg` has an expected, importable name.

      Args:
        msg: The container object whose type name is inspected.
        base_name: Middle part of the expected class name; the name must end
          with 'Repeated<base_name>Container' or
          'Repeated<base_name>FieldContainer'.
      """
      # Parse <type 'module.class_name'> to extract 'module.class_name' as a
      # string.
      tp_name = str(type(msg)).split("'")[1]
      valid_names = ('Repeated%sContainer' % base_name,
                     'Repeated%sFieldContainer' % base_name)
      # str.endswith accepts a tuple of suffixes. The failure message now
      # says "does not end with", which is what a failure actually means.
      self.assertTrue(tp_name.endswith(valid_names),
                      '%r does not end with any of %r' % (tp_name, valid_names))
      module_name, _, class_name = tp_name.rpartition('.')
      __import__(module_name, fromlist=[class_name])

    def testTypeNamesCanBeImported(self):
      """Container types of scalar and composite repeated fields import."""
      # If import doesn't work, pickling won't work either.
      pb = unittest_pb2.TestAllTypes()
      self.assertImportFromName(pb.repeated_int32, 'Scalar')
      self.assertImportFromName(pb.repeated_nested_message, 'Composite')
  @testing_refleaks.TestCase
  class PackedFieldTest(unittest.TestCase):
    """Compares packed and unpacked repeated-field encodings to golden bytes."""

    def setMessage(self, message):
      """Appends a single element to each repeated field of `message`."""
      single_values = (
          ('repeated_int32', 1),
          ('repeated_int64', 1),
          ('repeated_uint32', 1),
          ('repeated_uint64', 1),
          ('repeated_sint32', 1),
          ('repeated_sint64', 1),
          ('repeated_fixed32', 1),
          ('repeated_fixed64', 1),
          ('repeated_sfixed32', 1),
          ('repeated_sfixed64', 1),
          ('repeated_float', 1.0),
          ('repeated_double', 1.0),
          ('repeated_bool', True),
          ('repeated_nested_enum', 1),
      )
      for field_name, value in single_values:
        getattr(message, field_name).append(value)

    def testPackedFields(self):
      """TestPackedTypes serializes to the packed golden bytes."""
      packed = packed_field_test_pb2.TestPackedTypes()
      self.setMessage(packed)
      expected_bytes = (b'\x0A\x01\x01'
                        b'\x12\x01\x01'
                        b'\x1A\x01\x01'
                        b'\x22\x01\x01'
                        b'\x2A\x01\x02'
                        b'\x32\x01\x02'
                        b'\x3A\x04\x01\x00\x00\x00'
                        b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                        b'\x4A\x04\x01\x00\x00\x00'
                        b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                        b'\x5A\x04\x00\x00\x80\x3f'
                        b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f'
                        b'\x6A\x01\x01'
                        b'\x72\x01\x01')
      self.assertEqual(expected_bytes, packed.SerializeToString())

    def testUnpackedFields(self):
      """TestUnpackedTypes serializes to the unpacked golden bytes."""
      unpacked = packed_field_test_pb2.TestUnpackedTypes()
      self.setMessage(unpacked)
      expected_bytes = (b'\x08\x01'
                        b'\x10\x01'
                        b'\x18\x01'
                        b'\x20\x01'
                        b'\x28\x02'
                        b'\x30\x02'
                        b'\x3D\x01\x00\x00\x00'
                        b'\x41\x01\x00\x00\x00\x00\x00\x00\x00'
                        b'\x4D\x01\x00\x00\x00'
                        b'\x51\x01\x00\x00\x00\x00\x00\x00\x00'
                        b'\x5D\x00\x00\x80\x3f'
                        b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f'
                        b'\x68\x01'
                        b'\x70\x01')
      self.assertEqual(expected_bytes, unpacked.SerializeToString())
  @unittest.skipIf(api_implementation.Type() != 'cpp' or
                   sys.version_info < (2, 7),
                   'explicit tests of the C++ implementation for PY27 and above')
  @testing_refleaks.TestCase
  class OversizeProtosTest(unittest.TestCase):
    """Parses a message whose payload is just over 64 MiB, C++ backend only."""

    @classmethod
    def setUpClass(cls):
      """Builds the dynamic message class shared by all tests of this class."""
      # At the moment, reference cycles between DescriptorPool and Message
      # classes are not detected and these objects are never freed.
      # To avoid errors with ReferenceLeakChecker, we create the class only
      # once.
      file_desc = """
        name: "f/f.msg2"
        package: "f"
        message_type {
          name: "msg1"
          field {
            name: "payload"
            number: 1
            label: LABEL_OPTIONAL
            type: TYPE_STRING
          }
        }
        message_type {
          name: "msg2"
          field {
            name: "field"
            number: 1
            label: LABEL_OPTIONAL
            type: TYPE_MESSAGE
            type_name: "msg1"
          }
        }
      """
      file_proto = descriptor_pb2.FileDescriptorProto()
      text_format.Parse(file_desc, file_proto)

      pool = descriptor_pool.DescriptorPool()
      pool.Add(file_proto)

      factory = message_factory.MessageFactory(pool)
      cls.proto_cls = factory.GetPrototype(
          pool.FindMessageTypeByName('f.msg2'))

    def setUp(self):
      """Creates and serializes a message with a 64 MiB + 1 character payload."""
      payload_length = 1024 * 1024 * 64 + 1
      self.p = self.proto_cls()
      self.p.field.payload = 'c' * payload_length
      self.p_serialized = self.p.SerializeToString()

    def testAssertOversizeProto(self):
      """Checks the DecodeError text when oversize protos are disallowed."""
      from google.protobuf.pyext._message import SetAllowOversizeProtos
      SetAllowOversizeProtos(False)
      target = self.proto_cls()
      # NOTE(review): if parsing raises nothing, this test passes without
      # asserting anything; only the message of a DecodeError is checked.
      try:
        target.ParseFromString(self.p_serialized)
      except message.DecodeError as err:
        self.assertEqual(str(err), 'Error parsing message')

    def testSucceedOversizeProto(self):
      """Parsing succeeds and preserves the payload when oversize is allowed."""
      from google.protobuf.pyext._message import SetAllowOversizeProtos
      SetAllowOversizeProtos(True)
      target = self.proto_cls()
      target.ParseFromString(self.p_serialized)
      self.assertEqual(self.p.field.payload, target.field.payload)
  # Run every test case in this module when it is executed as a script.
  if __name__ == '__main__':
    unittest.main()