root/trac/hacks/marketplugin/0.9/tracmarket/db.py

Revision 117 (checked in by stevegt, 6 years ago)

checkpoint after mss tutorial, before moving order validity checks until after order entry into db

Line 
1 import os
2 import time
3 import types
4 from trac.core import Component
5
6 # XXX handle OperationalError for locked db, in select, update, and
7 # insert; also find all standalone selects and handle there
8
9 def kw2sql(*args, **kwargs):
10     """
11
12         >>> kw2sql(foo='a', bar='b')
13         ('foo = %s AND bar = %s', ['a', 'b'])
14         >>> kw2sql(', ', foo='a', bar='b')
15         ('foo = %s , bar = %s', ['a', 'b'])
16         >>> kw2sql('OR', foo='a', bar='b')
17         ('foo = %s OR bar = %s', ['a', 'b'])
18         >>> kw2sql('OR', foo='a', bar=('>', 'b'))
19         ('foo = %s OR bar > %s', ['a', 'b'])
20         >>> kw2sql('OR', dict(foo='a'), dict(foo='b'))
21         ('(foo = %s) OR (foo = %s)', ['a', 'b'])
22         >>> kw2sql('OR', dict(foo='a'), dict(foo='b'), bar=23, baz=42)
23         ('(foo = %s) OR (foo = %s) OR baz = %s OR bar = %s', ['a', 'b', 42, 23])
24         >>> class T(DbObject):
25         ...     NAME='footable'
26         >>> kw2sql(T, foo='a', bar='b')
27         ('footable.foo = %s AND footable.bar = %s', ['a', 'b'])
28
29     """
30     parms = []
31     terms = []
32     op = 'AND'
33     table_class = None
34     if args:
35         for arg in args:
36             if arg is None:
37                 continue
38             elif type(arg) is type(str()):
39                 op = arg
40             elif type(arg) in (type(list()), type(tuple())):
41                 s, p = kw2sql(table_class, *arg)
42                 if s:
43                     terms.append("(%s)" % s)
44                     parms += p
45             elif type(arg) is type(dict()):
46                 s, p = kw2sql(table_class, **arg)
47                 if s:
48                     terms.append("(%s)" % s)
49                     parms += p
50             elif issubclass(arg, DbObject):
51                 table_class = arg
52             else:
53                 assert False, arg
54     if kwargs:
55         table_name = None
56         if table_class:
57             table_name = table_class.NAME
58         for var, val in kwargs.items():
59             eq = '='
60             if val is None:
61                 continue
62             if type(val) in (type(list()), type(tuple())):
63                 eq = val[0]
64                 val = val[1]
65             if isinstance(val, Select):
66                 if table_name:
67                     terms.append('%s.%s %s (%s)' % (table_name, var, eq, val.sql))
68                 else:
69                     terms.append('%s %s (%s)' % (var, eq, val.sql))
70                 parms += val.parms
71             else:
72                 if table_name:
73                     terms.append("%s.%s %s %%s" % (table_name, var, eq))
74                 else:
75                     terms.append("%s %s %%s" % (var, eq))
76                 parms.append(val)
77     op = ' %s ' % op
78     sql = op.join(terms)
79     return sql, parms
80
81 class Select(object):
82
83     # XXX get rid of env
84     def __init__(self, env, db, cls, columns=None, append=None, parms=None,
85             order=None, limit=None, *args, **kwargs):
86         self.env = env
87         self.db = db
88         self.cls = cls
89         self.debug = os.environ.get("SQLDEBUG", None)
90         self.parms = []
91         sql = "SELECT "
92         if type(columns) is type('a'):
93             columns = columns.split(', ')
94         if not columns:
95             columns = '*'
96         sql += '%s ' % ', '.join(columns)
97         names = []
98         if type(cls) not in (type(list()), type(tuple())):
99             cls = (cls, )
100         for c in cls:
101             names.append(c.NAME)
102         sql += 'FROM %s ' % ', '.join(names)
103         if args or kwargs:
104             s, p = kw2sql(cls[0], *args, **kwargs)
105             if s:
106                 sql += "WHERE %s " % s
107                 self.parms += p
108                 if append:
109                     sql += "AND "
110         elif append:
111             sql += "WHERE "
112         if append:
113             sql += " %s " % append
114             if parms:
115                 self.parms += parms
116         if order:
117             sql += "ORDER BY %s " % order
118         if limit:
119             sql += "LIMIT %d " % limit
120         self.sql = sql
121
122     def cursor(self):
123         cursor = self.db.cursor()
124         if self.parms:
125             if self.debug:
126                 self.env.log.debug("'%s', %s" % (self.sql, self.parms))
127             cursor.execute(self.sql, self.parms)
128         else:
129             if self.debug:
130                 self.env.log.debug("'%s'" % (self.sql))
131             cursor.execute(self.sql)
132         return cursor
133
134     def rows(self):
135         cursor = self.cursor()
136         start = time.time()
137         while True:
138             row = cursor.fetchone()
139             if not row:
140                 elapsed = time.time() - start
141                 if elapsed > .1:
142                     self.env.log.info(
143                         "query took %f seconds: %s" % (elapsed, self.sql))
144                 raise StopIteration
145             yield row
146     rows = property(rows)
147
148     def objs(self):
149         assert issubclass(self.cls, DbObject)
150         for row in self.rows:
151             kwargs = {}
152             for var, val in zip(row.keys(), row):
153                 kwargs[var] = val
154             obj = self.cls(self.env, self.db, **kwargs)
155             yield obj
156     objs = property(objs)
157
158     def table_name(self):
159         c = self.cls
160         if type(c) in (type(list()), type(tuple())):
161             c = c[0]
162         return c.NAME
163     table_name = property(table_name)
164
165 class InitDB(Component):
166
167     # IEnvironmentSetupParticipant methods
168
169     def environment_needs_upgrade(self, db):
170         cursor = db.cursor()
171         # print self.dbobjects
172         for cls in self.dbobjects:
173             if cls.upgrade_db(self.env, db, noop=True):
174                 return True
175         return False
176
177     def upgrade_environment(self, db):
178         for cls in self.dbobjects:
179             cls.upgrade_db(self.env, db)
180
181
182 class DbObject(dict):
183
184     def __init__(self, env, db, **kwargs):
185         self._env = env
186         self._db = db
187         self._set_defaults()
188         if not kwargs:
189             return
190         for var, val in kwargs.items():
191             if val is None:
192                 continue
193             self[var] = val
194
195     def __str__(self):
196         import pprint
197         return pprint.pformat(self)
198         # return str(self)
199
200     def _set_defaults(self):
201         pass
202
203     def __getattr__(self, name):
204         if name in self.COLUMNS:
205             return self.get(name, None)
206         raise AttributeError, "column %s not in table %s" % (name, self.NAME)
207
208     def chattr(self, **kwargs):
209         for name in kwargs:
210             if name not in self.COLUMNS:
211                 raise AttributeError
212         super(DbObject, self).update(kwargs)
213
214     def upgrade_db(cls, env, db, noop=False):
215         version = cls.VERSION
216         upgrade = False
217         cursor = db.cursor()
218         cursor.execute("SELECT value FROM system "
219                        "WHERE name='market_%s_db_version'" % cls.__name__)
220         if cursor.rowcount != 1:
221             db_version = 0
222         else:
223             db_version = int(cursor.fetchone()[0])
224         if db_version < cls.VERSION:
225             upgrade = True
226         if noop:
227             return upgrade
228         for i in range(db_version, version+1):
229             method = "_upgrade_to_v%d" % i
230             if hasattr(cls, method):
231                 getattr(cls, method)(env, db)
232                 env.log.debug("Done upgrading %s to version %d" %
233                         (cls.__name__, i))
234         cursor.execute("DELETE FROM system "
235                        "WHERE name='market_%s_db_version'" % cls.__name__)
236         cursor.execute("INSERT INTO system (name, value) "
237                        "VALUES ('market_%s_db_version', %%s)" % cls.__name__,
238                        (version, ))
239         env.log.debug("Db upgraded: %s" % cls.__name__)
240     upgrade_db = classmethod(upgrade_db)
241
242     def _upgrade_to_v1(cls, env, db):
243         table = cls.TABLE
244         cursor = db.cursor()
245         for stmt in db.to_sql(table):
246             cursor.execute(stmt)
247         env.log.debug("Done creating table %s" % table.name)
248     _upgrade_to_v1 = classmethod(_upgrade_to_v1)
249
250     def reload(self):
251         '''return a duplicate object from db'''
252         # XXX assumes id is primary key (or at least unique)
253         return self.__class__.selectone(self._env, self._db, id=self.id)
254
255     def select(cls, env, db, append=None, parms=None, **kwargs):
256         s = Select(env, db, cls, append=append, parms=parms, **kwargs)
257         return s.objs
258     select = classmethod(select)
259
260     def selectone(cls, env, db, *args, **kwargs):
261         s = cls.select(env, db, *args, **kwargs)
262         try:
263             obj = s.next()
264         except StopIteration:
265             obj = None
266         return obj
267     selectone = classmethod(selectone)
268
269     def _ck_values(self):
270         """Check values before update.  Raise ValueError if bad value."""
271         pass
272
273     def _ignore(self, *args):
274         for var in self.COLUMNS:
275             if var in args:
276                 continue
277             self._assert(var)
278
279     def _assert(self, *args):
280         for var in args:
281             if self.get(var, None) is None:
282                 raise ValueError("missing value: %s" % var)
283
284     def insert(self, db=None, cursor=None):
285         if not db:
286             db = self._db
287         if not cursor:
288             cursor = db.cursor()
289         self._ck_values()
290         sql = "INSERT INTO %s " % self.NAME
291         colnames = []
292         format = []
293         parms = []
294         for var, val in self.items():
295             colnames.append(var)
296             format.append('%s')
297             parms.append(val)
298         sql += "(%s) " % ', '.join(colnames)
299         sql += "VALUES (%s) " % ', '.join(format)
300         # print sql, parms
301         cursor.execute(sql, parms)
302         if cursor.rowcount != 1:
303             return None
304         return self
305
306     def update(self, set=None, parms=None, **kwargs):
307         db = self._db
308         cursor = db.cursor()
309         sql = "UPDATE %s SET " % self.NAME
310         if not set:
311             set = []
312         if not parms:
313             parms = []
314         for var, val in kwargs.items():
315             self[var]=val
316             set.append('%s=%%s' % var)
317             parms.append(val)
318         sql += "%s " % ', '.join(set)
319         kw = {}
320         # update this record only
321         for col in self.TABLE.key:
322             kw[col] = self[col]
323         where, whereparms = kw2sql(**kw)
324         sql += "WHERE " + where
325         parms += whereparms
326         # print sql, parms
327         self._ck_values()
328         cursor.execute(sql, parms)
329         return cursor.rowcount
330
331     def delete(self):
332         db = self._db
333         cursor = db.cursor()
334         sql = "DELETE FROM %s WHERE " % self.NAME
335         parms = []
336         where = []
337         # delete this record only
338         for col in self.TABLE.key:
339             where.append("%s=%%s" % col)
340             parms.append(self[col])
341         sql += "%s " % ' AND '.join(where)
342         # print sql, parms
343         cursor.execute(sql, parms)
344         return cursor.rowcount
345
346
347 def uniq(set):
348     u = {}
349     for s in set:
350         u[s] = 1
351     return u.keys()
Note: See TracBrowser for help on using the browser.