Commit 5ebaa626 authored by Sylvain Thénault's avatar Sylvain Thénault
Browse files

[json] fix json serialization for recent simplejson implementation, and test encoding of entities

as with earlier simplejson implementation, iterencode internal stuff
is a generated function, we can't anymore rely on the _iterencode
overriding trick, so move on by stoping isinstance(Entity, dict).

This is a much heavier change than expected but it was expected
to be done at some point, so let's go that way.
parent b5d595b66c35
......@@ -19,6 +19,7 @@
__docformat__ = "restructuredtext en"
from copy import copy
from warnings import warn
from logilab.common import interface
......@@ -51,7 +52,7 @@ def greater_card(rschema, subjtypes, objtypes, index):
return '1'
class Entity(AppObject, dict):
class Entity(AppObject):
"""an entity instance has e_schema automagically set on
the class and instances has access to their issuing cursor.
......@@ -287,17 +288,17 @@ class Entity(AppObject, dict):
def __init__(self, req, rset=None, row=None, col=0):
AppObject.__init__(self, req, rset=rset, row=row, col=col)
dict.__init__(self)
self._cw_related_cache = {}
if rset is not None:
self.eid = rset[row][col]
else:
self.eid = None
self._cw_is_saved = True
self.cw_attr_cache = {}
def __repr__(self):
return '<Entity %s %s %s at %s>' % (
self.e_schema, self.eid, self.keys(), id(self))
self.e_schema, self.eid, self.cw_attr_cache.keys(), id(self))
def __json_encode__(self):
"""custom json dumps hook to dump the entity's eid
......@@ -316,12 +317,18 @@ class Entity(AppObject, dict):
def __cmp__(self, other):
raise NotImplementedError('comparison not implemented for %s' % self.__class__)
def __contains__(self, key):
return key in self.cw_attr_cache
def __iter__(self):
return iter(self.cw_attr_cache)
def __getitem__(self, key):
if key == 'eid':
warn('[3.7] entity["eid"] is deprecated, use entity.eid instead',
DeprecationWarning, stacklevel=2)
return self.eid
return super(Entity, self).__getitem__(key)
return self.cw_attr_cache[key]
def __setitem__(self, attr, value):
"""override __setitem__ to update self.edited_attributes.
......@@ -339,7 +346,7 @@ class Entity(AppObject, dict):
DeprecationWarning, stacklevel=2)
self.eid = value
else:
super(Entity, self).__setitem__(attr, value)
self.cw_attr_cache[attr] = value
# don't add attribute into skip_security if already in edited
# attributes, else we may accidentaly skip a desired security check
if hasattr(self, 'edited_attributes') and \
......@@ -363,13 +370,16 @@ class Entity(AppObject, dict):
del self.entity['load_left']
"""
super(Entity, self).__delitem__(attr)
del self.cw_attr_cache[attr]
if hasattr(self, 'edited_attributes'):
self.edited_attributes.remove(attr)
def get(self, key, default=None):
return self.cw_attr_cache.get(key, default)
def setdefault(self, attr, default):
"""override setdefault to update self.edited_attributes"""
super(Entity, self).setdefault(attr, default)
self.cw_attr_cache.setdefault(attr, default)
# don't add attribute into skip_security if already in edited
# attributes, else we may accidentaly skip a desired security check
if hasattr(self, 'edited_attributes') and \
......@@ -382,9 +392,9 @@ class Entity(AppObject, dict):
undesired changes introduced in the entity's dict. See `__delitem__`
"""
if default is _marker:
value = super(Entity, self).pop(attr)
value = self.cw_attr_cache.pop(attr)
else:
value = super(Entity, self).pop(attr, default)
value = self.cw_attr_cache.pop(attr, default)
if hasattr(self, 'edited_attributes') and attr in self.edited_attributes:
self.edited_attributes.remove(attr)
return value
......@@ -556,6 +566,12 @@ class Entity(AppObject, dict):
# entity cloning ##########################################################
def cw_copy(self):
thecopy = copy(self)
thecopy.cw_attr_cache = copy(self.cw_attr_cache)
thecopy._cw_related_cache = {}
return thecopy
def copy_relations(self, ceid): # XXX cw_copy_relations
"""copy relations of the object with the given eid on this
object (this method is called on the newly created copy, and
......@@ -668,7 +684,7 @@ class Entity(AppObject, dict):
selected = []
for attr in (attributes or self._cw_to_complete_attributes(skip_bytes, skip_pwd)):
# if attribute already in entity, nothing to do
if self.has_key(attr):
if self.cw_attr_cache.has_key(attr):
continue
# case where attribute must be completed, but is not yet in entity
var = varmaker.next()
......@@ -727,7 +743,7 @@ class Entity(AppObject, dict):
:param name: name of the attribute to get
"""
try:
value = self[name]
value = self.cw_attr_cache[name]
except KeyError:
if not self.cw_is_saved():
return None
......@@ -952,7 +968,7 @@ class Entity(AppObject, dict):
# clear attributes cache
haseid = 'eid' in self
self._cw_completed = False
self.clear()
self.cw_attr_cache.clear()
# clear relations cache
self.cw_clear_relation_cache()
# rest path unique cache
......@@ -1020,7 +1036,7 @@ class Entity(AppObject, dict):
This method is for internal use, you should not use it.
"""
super(Entity, self).__setitem__(attr, value)
self.cw_attr_cache[attr] = value
def _cw_clear_local_perm_cache(self, action):
for rqlexpr in self.e_schema.get_rqlexprs(action):
......@@ -1037,7 +1053,7 @@ class Entity(AppObject, dict):
def _cw_set_defaults(self):
"""set default values according to the schema"""
for attr, value in self.e_schema.defaults():
if not self.has_key(attr):
if not self.cw_attr_cache.has_key(attr):
self[str(attr)] = value
def _cw_check(self, creation=False):
......
......@@ -17,8 +17,8 @@
# with CubicWeb. If not, see <http://www.gnu.org/licenses/>.
"""Helper classes to execute RQL queries on a set of sources, performing
security checking and data aggregation.
"""
from __future__ import with_statement
__docformat__ = "restructuredtext en"
......
......@@ -910,7 +910,7 @@ class Repository(object):
self._extid_cache[cachekey] = eid
self._type_source_cache[eid] = (etype, source.uri, extid)
entity = source.before_entity_insertion(session, extid, etype, eid)
entity.edited_attributes = set(entity)
entity.edited_attributes = set(entity.cw_attr_cache)
if source.should_call_hooks:
self.hm.call_hooks('before_add_entity', session, entity=entity)
# XXX call add_info with complete=False ?
......@@ -1021,7 +1021,7 @@ class Repository(object):
"""
# init edited_attributes before calling before_add_entity hooks
entity._cw_is_saved = False # entity has an eid but is not yet saved
entity.edited_attributes = set(entity) # XXX cw_edited_attributes
entity.edited_attributes = set(entity.cw_attr_cache) # XXX cw_edited_attributes
eschema = entity.e_schema
source = self.locate_etype_source(entity.__regid__)
# allocate an eid to the entity before calling hooks
......@@ -1036,7 +1036,7 @@ class Repository(object):
# XXX use entity.keys here since edited_attributes is not updated for
# inline relations XXX not true, right? (see edited_attributes
# affectation above)
for attr in entity.iterkeys():
for attr in entity.cw_attr_cache.iterkeys():
rschema = eschema.subjrels[attr]
if not rschema.final: # inlined relation
relations.append((attr, entity[attr]))
......
......@@ -22,8 +22,6 @@ from __future__ import with_statement
__docformat__ = "restructuredtext en"
from copy import copy
from rql.stmts import Union, Select
from rql.nodes import Constant, Relation
......@@ -479,7 +477,7 @@ class InsertRelationsStep(Step):
result = [[]]
for row in result:
# get a new entity definition for this row
edef = copy(base_edef)
edef = base_edef.cw_copy()
# complete this entity def using row values
index = 0
for rtype, rorder, value in self.rdefs:
......
......@@ -15,16 +15,16 @@
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb. If not, see <http://www.gnu.org/licenses/>.
"""unit tests for module cubicweb.utils
"""
"""unit tests for module cubicweb.utils"""
import re
import decimal
import datetime
from logilab.common.testlib import TestCase, unittest_main
from cubicweb.utils import make_uid, UStringIO, SizeConstrainedList, RepeatList
from cubicweb.entity import Entity
try:
from cubicweb.utils import CubicWebJsonEncoder, json
......@@ -99,6 +99,7 @@ class RepeatListTC(TestCase):
l.pop(2)
self.assertEquals(l, [(1, 3)]*2)
class SizeConstrainedListTC(TestCase):
def test_append(self):
......@@ -117,6 +118,7 @@ class SizeConstrainedListTC(TestCase):
l.extend(extension)
yield self.assertEquals, l, expected
class JSONEncoderTC(TestCase):
def setUp(self):
if json is None:
......@@ -136,6 +138,20 @@ class JSONEncoderTC(TestCase):
def test_encoding_decimal(self):
self.assertEquals(self.encode(decimal.Decimal('1.2')), '1.2')
def test_encoding_bare_entity(self):
e = Entity(None)
e['pouet'] = 'hop'
e.eid = 2
self.assertEquals(json.loads(self.encode(e)),
{'pouet': 'hop', 'eid': 2})
def test_encoding_entity_in_list(self):
e = Entity(None)
e['pouet'] = 'hop'
e.eid = 2
self.assertEquals(json.loads(self.encode([e])),
[{'pouet': 'hop', 'eid': 2}])
def test_encoding_unknown_stuff(self):
self.assertEquals(self.encode(TestCase), 'null')
......
......@@ -335,21 +335,11 @@ else:
class CubicWebJsonEncoder(json.JSONEncoder):
"""define a json encoder to be able to encode yams std types"""
# _iterencode is the only entry point I've found to use a custom encode
# hook early enough: .default() is called if nothing else matched before,
# .iterencode() is called once on the main structure to encode and then
# never gets called again.
# For the record, our main use case is in FormValidateController with:
# json.dumps((status, args, entity), cls=CubicWebJsonEncoder)
# where we want all the entity attributes, including eid, to be part
# of the json object dumped.
# This would have once more been easier if Entity didn't extend dict.
def _iterencode(self, obj, markers=None):
if hasattr(obj, '__json_encode__'):
obj = obj.__json_encode__()
return json.JSONEncoder._iterencode(self, obj, markers)
def default(self, obj):
if hasattr(obj, 'eid'):
d = obj.cw_attr_cache.copy()
d['eid'] = obj.eid
return d
if isinstance(obj, datetime.datetime):
return obj.strftime('%Y/%m/%d %H:%M:%S')
elif isinstance(obj, datetime.date):
......
......@@ -395,12 +395,12 @@ class OneWeekCal(EntityView):
# colors here are class names defined in cubicweb.css
colors = [ "col%x" % i for i in range(12) ]
next_color_index = 0
done_tasks = []
done_tasks = set()
for row in xrange(self.cw_rset.rowcount):
task = self.cw_rset.get_entity(row, 0)
if task in done_tasks:
if task.eid in done_tasks:
continue
done_tasks.append(task)
done_tasks.add(task.eid)
the_dates = []
icalendarable = task.cw_adapt_to('ICalendarable')
tstart = icalendarable.start
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment