def field_class_to_schematics_field(field: peewee.Field) -> BaseType:
    """Build a schematics type instance that validates values of a peewee field.

    The value type is chosen from the peewee field class. For a
    ``ForeignKeyField`` the type of the referenced column (``rel_field``) is
    used. ``required`` and the description always come from ``field`` itself,
    because they describe the column being validated, not the one it points to.

    :param field: the peewee field (column) to convert.
    :return: a schematics type instance (``IntType``, ``StringType``, ...),
        or ``None`` when the field class has no mapping below.
    """
    kwargs = {}
    # A value may be omitted when the column has a default, accepts NULL,
    # is filled by a sequence, or is an auto-increment primary key.
    # This is evaluated before the foreign-key substitution below on purpose:
    # the referenced column is usually an AutoField, which would otherwise
    # make every foreign key optional regardless of its own ``null`` setting.
    if not ((field.default is not None) or field.null or field.sequence or isinstance(field, peewee.AutoField)):
        kwargs['required'] = True
    if field.help_text:
        kwargs['metadata'] = {'description': field.help_text}

    if isinstance(field, peewee.ForeignKeyField):
        # A foreign key stores values of the referenced column's type.
        field = field.rel_field

    if isinstance(field, peewee.IntegerField):
        return IntType(**kwargs)
    if isinstance(field, peewee.FloatField):
        return FloatType(**kwargs)
    if isinstance(field, (PG_JSONField, PG_BinaryJSONField, SQLITE_JSONField)):
        # SQLITE_JSONField is a peewee._StringField subclass, so it must be
        # matched before the generic string check below. JSON can hold any
        # value (HStore would be the dict-only case), hence plain JSONType.
        return JSONType(**kwargs)
    if isinstance(field, peewee.DateTimeField):
        return DateTimeType(**kwargs)
    if isinstance(field, peewee.DateField):
        return DateType(**kwargs)
    if isinstance(field, peewee._StringField):
        return StringType(**kwargs)
    if isinstance(field, peewee.BooleanField):
        return BooleanType(**kwargs)
    if isinstance(field, peewee.BlobField):
        return BlobType(**kwargs)
    if isinstance(field, PG_ArrayField):
        # The element field lives on ArrayField's private ``__field``
        # attribute, reachable only through its name-mangled form.
        # noinspection PyProtectedMember
        element_field = field._ArrayField__field
        return JSONListType(field_class_to_schematics_field(element_field), **kwargs)
    # No mapping for this field class.
    return None
# noinspection PyProtectedMember
def update(self, records: Iterable[DataRecord], values: SQLValuesToWrite, returning=False) -> Union[int, Iterable[DataRecord]]:
    """Write ``values`` to the rows selected by ``records``.

    Keys of ``values`` that are not among the view's peewee fields are skipped.

    For ``ArrayField`` columns:
    - a key listed in ``values.set_add_fields`` becomes an SQL set union with
      the current column value;
    - a key listed in ``values.set_remove_fields`` becomes an SQL set
      difference with the current column value.

    For other columns:
    - a key listed in ``values.incr_fields`` becomes ``column + value``;
    - a key listed in ``values.decr_fields`` becomes ``column - value``.

    The whole write runs inside ``db.atomic()`` and ``PeeweeContext(db)``.

    :param records: records used by ``self._build_write_condition`` to build
        the WHERE condition.
    :param values: column -> value mapping, plus the set/incr/decr field sets.
    :param returning: if true, return the updated rows wrapped as
        ``PeeweeDataRecord`` objects instead of the result of the UPDATE.
    :return: the result of executing the UPDATE, or a list of records when
        ``returning``. On PostgreSQL the rows come from ``UPDATE ... RETURNING``.
        On other databases they are re-selected with the same condition after
        the update.
    """
    new_vals = {}
    model = self.vcls.model
    db = self.vcls.model._meta.database
    fields = self.vcls._peewee_fields
    cond = self._build_write_condition(records)

    def as_record(row):
        # Wrap one database row as a data record bound to this view.
        return PeeweeDataRecord(None, row, view=self.vcls)

    for name, value in values.items():
        if name not in fields:
            continue
        field = fields[name]

        if not isinstance(field, ArrayField):
            if name in values.incr_fields:
                value = field + value
            if name in values.decr_fields:
                value = field - value
            new_vals[name] = value
            continue

        # The value is wrapped as [value] because SQL params must be a list,
        # e.g. [v1, v2, v3]; the array itself becomes the single parameter.
        if name in values.set_add_fields:
            value = SQL('(select ARRAY((select unnest(%s)) union (select unnest(%%s))))' % field.column_name, [value])
        if name in values.set_remove_fields:
            value = SQL('(select ARRAY((select unnest(%s)) except (select unnest(%%s))))' % field.column_name, [value])
        # TODO: array_append / array_remove operations are not enabled yet.
        new_vals[name] = value

    with db.atomic(), PeeweeContext(db):
        if not isinstance(db, peewee.PostgresqlDatabase):
            count = model.update(**new_vals).where(cond).execute()
            if not returning:
                return count
            # No RETURNING support assumed here: re-read the rows afterwards.
            return [as_record(row) for row in model.select().where(cond).execute()]

        query = model.update(**new_vals).where(cond)
        if not returning:
            return query.execute()
        cursor = query.returning(*model._meta.fields.values()).execute()
        # Materialise inside the transaction/context before it closes.
        return [as_record(row) for row in cursor]