1 package com.imcode.db.mock;
2
3 import com.imcode.db.AbstractDatabase;
4 import com.imcode.db.DatabaseCommand;
5 import com.imcode.db.DatabaseException;
6 import junit.framework.Assert;
7 import org.apache.commons.collections.CollectionUtils;
8 import org.apache.commons.collections.Predicate;
9 import org.apache.commons.dbutils.ResultSetHandler;
10 import org.apache.commons.lang.ArrayUtils;
11 import org.apache.commons.lang.StringUtils;
12
13 import java.sql.ResultSet;
14 import java.sql.SQLException;
15 import java.util.ArrayList;
16 import java.util.Arrays;
17 import java.util.Iterator;
18 import java.util.List;
19 import java.util.Map;
20 import java.util.regex.Matcher;
21 import java.util.regex.Pattern;
22
23 public class MockDatabase extends AbstractDatabase {
24
25 private List sqlCalls = new ArrayList();
26 private List expectedSqlCalls = new ArrayList();
27
28 public int executeUpdate(String sqlStr, Object[] parameters) {
29 getResultForSqlCall(sqlStr, parameters);
30 return 0;
31 }
32
33
34 public Object executeQuery(String sqlQuery, Object[] parameters, ResultSetHandler resultSetHandler) {
35 ResultSet resultSet = (ResultSet) getResultForSqlCall(sqlQuery, parameters);
36 if (null == resultSet ) {
37 resultSet = new MockResultSet(new Object[0][]) ;
38 }
39 try {
40 return resultSetHandler.handle(resultSet) ;
41 } catch ( SQLException e ) {
42 throw DatabaseException.fromSQLException("", e);
43 }
44 }
45
46 public Object execute(DatabaseCommand databaseCommand) throws DatabaseException {
47 return databaseCommand.executeOn(new MockDatabaseConnection(this));
48 }
49
50 public void addExpectedSqlCall(final SqlCallPredicate sqlCallPredicate, final Object result) {
51 expectedSqlCalls.add(new Map.Entry() {
52 public Object getKey() {
53 return sqlCallPredicate;
54 }
55
56 public Object getValue() {
57 return result;
58 }
59
60 public Object setValue(Object value) {
61 throw new UnsupportedOperationException();
62 }
63
64 public String toString() {
65 return sqlCallPredicate + ": " + result;
66 }
67 });
68 }
69
70 public void assertExpectedSqlCalls() {
71 if (!expectedSqlCalls.isEmpty()) {
72 Assert.fail("Remaining expected sql calls: " + expectedSqlCalls.toString());
73 }
74 }
75
76 public int getSqlCallCount() {
77 return sqlCalls.size();
78 }
79
80 Object getResultForSqlCall(String sql, Object[] params) {
81 SqlCall sqlCall = new SqlCall(sql, params);
82 sqlCalls.add(sqlCall);
83 Object result = null;
84 if (!expectedSqlCalls.isEmpty()) {
85 Map.Entry entry = (Map.Entry) expectedSqlCalls.get(0);
86 SqlCallPredicate predicate = (SqlCallPredicate) entry.getKey();
87 if (predicate.evaluateSqlCall(sqlCall)) {
88 result = entry.getValue();
89 expectedSqlCalls.remove(0);
90 }
91 }
92 return result;
93 }
94
95 public static class SqlCall {
96
97 private String string;
98 private Object[] parameters;
99
100 public SqlCall(String string, Object[] parameters) {
101 this.string = string;
102 this.parameters = parameters;
103 }
104
105 public String getString() {
106 return string;
107 }
108
109 public Object[] getParameters() {
110 return parameters;
111 }
112
113 public String toString() {
114 return getString() + " " + StringUtils.join(getParameters(), ", ");
115 }
116
117 }
118
119 public void assertCalled(SqlCallPredicate predicate) {
120 assertCalled(null, predicate);
121 }
122
123 public void assertCalledInOrder(SqlCallPredicate[] sqlCallPredicates) {
124 int sqlCallPredicatesIndex = 0 ;
125 for ( Iterator iterator = sqlCalls.iterator(); iterator.hasNext(); ) {
126 SqlCall sqlCall = (SqlCall) iterator.next();
127 if (sqlCallPredicates[sqlCallPredicatesIndex].evaluateSqlCall(sqlCall)) {
128 sqlCallPredicatesIndex++ ;
129 if (sqlCallPredicatesIndex == sqlCallPredicates.length) {
130 break ;
131 }
132 }
133 }
134 if (sqlCallPredicatesIndex < sqlCallPredicates.length) {
135 String failureMessage = "Expected sql call \"" + sqlCallPredicates[sqlCallPredicatesIndex].getFailureMessage()+"\"";
136 if (sqlCallPredicatesIndex > 0) {
137 failureMessage += " after sql call \""+sqlCallPredicates[sqlCallPredicatesIndex-1]+"\"" ;
138 }
139 Assert.fail(failureMessage) ;
140 }
141 }
142
143 public void assertCalled(String message, SqlCallPredicate predicate) {
144 if (!called(predicate)) {
145 String messagePrefix = null == message ? "" : message + " ";
146 Assert.fail(messagePrefix + "Expected at least one sql call: " + predicate.getFailureMessage());
147 }
148 }
149
150 private boolean called(SqlCallPredicate predicate) {
151 return CollectionUtils.exists(sqlCalls, predicate);
152 }
153
154 public void assertNotCalled(SqlCallPredicate sqlCallPredicate) {
155 assertNotCalled(null, sqlCallPredicate);
156 }
157
158 public void assertNotCalled(String message, SqlCallPredicate predicate) {
159 if (called(predicate)) {
160 String messagePrefix = null == message ? "" : message + " ";
161 Assert.fail(messagePrefix + "Got unexpected sql call: " + predicate.getFailureMessage());
162 }
163 }
164
165 public void assertCallCount(int expectedCount, SqlCallPredicate predicate) {
166 int actualCount = CollectionUtils.countMatches(sqlCalls, predicate);
167 if (expectedCount != actualCount) {
168 Assert.fail("Expected " + expectedCount + ", but got " + actualCount + " sql calls: " + predicate.getFailureMessage());
169 }
170 }
171
172 public abstract static class SqlCallPredicate implements Predicate {
173
174 public final boolean evaluate(Object object) {
175 return evaluateSqlCall((MockDatabase.SqlCall) object);
176 }
177
178 abstract boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall);
179
180 abstract String getFailureMessage();
181
182 public String toString() {
183 return getFailureMessage();
184 }
185 }
186
187 public static class UpdateTableSqlCallPredicate extends SqlCallPredicate {
188
189 private String tableName;
190 private Object parameter;
191
192 public UpdateTableSqlCallPredicate(String tableName, Object parameter) {
193 this.tableName = tableName;
194 this.parameter = parameter;
195 }
196
197 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
198 boolean stringMatchesUpdateTableName = Pattern.compile("^update//s+//b" + tableName+"//b").matcher(sqlCall.getString().toLowerCase()).find();
199 boolean parametersContainsParameter = ArrayUtils.contains(sqlCall.getParameters(), parameter);
200 return stringMatchesUpdateTableName && parametersContainsParameter;
201 }
202
203 String getFailureMessage() {
204 return "update of table " + tableName + " with one parameter = " + parameter;
205 }
206 }
207
208 public static class InsertIntoTableSqlCallPredicate extends MatchesRegexSqlCallPredicate {
209
210 private String tableName;
211
212 public InsertIntoTableSqlCallPredicate(String tableName) {
213 super("^insert//s+(?:into//s+)?//b" + tableName+"//b") ;
214 this.tableName = tableName;
215 }
216
217 String getFailureMessage() {
218 return "insert into table " + tableName ;
219 }
220 }
221
222 public static class InsertIntoTableWithParameterSqlCallPredicate extends InsertIntoTableSqlCallPredicate {
223
224 private String parameter;
225
226 public InsertIntoTableWithParameterSqlCallPredicate(String tableName, String parameter) {
227 super(tableName);
228 this.parameter = parameter;
229 }
230
231 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
232 return super.evaluateSqlCall(sqlCall) && ArrayUtils.contains(sqlCall.getParameters(), parameter);
233 }
234
235 String getFailureMessage() {
236 return super.getFailureMessage() + " with one parameter = \"" + parameter + "\"";
237 }
238 }
239
240 public static class MatchesRegexSqlCallPredicate extends SqlCallPredicate {
241
242 private String regex;
243
244 public MatchesRegexSqlCallPredicate(String regex) {
245 this.regex = regex;
246 }
247
248 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
249 Pattern pattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE);
250 Matcher matcher = pattern.matcher(sqlCall.getString());
251 return matcher.find();
252 }
253
254 String getFailureMessage() {
255 return "Expected call to match regex " + regex;
256 }
257 }
258
259 public static class EqualsSqlCallPredicate extends SqlCallPredicate {
260
261 String sql;
262
263 public EqualsSqlCallPredicate(String sql) {
264 this.sql = sql;
265 }
266
267 boolean evaluateSqlCall(SqlCall sqlCall) {
268 return sql.equalsIgnoreCase(sqlCall.getString());
269 }
270
271 String getFailureMessage() {
272 return "sql \"" + sql + "\"";
273 }
274 }
275
276 public static class StartsWithSqlCallPredicate extends SqlCallPredicate {
277
278 private String prefix;
279
280 public StartsWithSqlCallPredicate(String prefix) {
281 this.prefix = prefix;
282 }
283
284 boolean evaluateSqlCall(SqlCall sqlCall) {
285 return sqlCall.getString().startsWith(prefix);
286 }
287
288 String getFailureMessage() {
289 return "start with " + prefix;
290 }
291 }
292
293 public static class EqualsWithParametersSqlCallPredicate extends EqualsSqlCallPredicate {
294
295 private String[] parameters;
296
297 public EqualsWithParametersSqlCallPredicate(String sql, String[] parameters) {
298 super(sql);
299 this.parameters = parameters;
300 }
301
302 boolean evaluateSqlCall(SqlCall sqlCall) {
303 return super.evaluateSqlCall(sqlCall) && Arrays.equals(parameters, sqlCall.getParameters());
304 }
305
306 String getFailureMessage() {
307 return super.getFailureMessage() + " with parameters " + ArrayUtils.toString(parameters);
308 }
309 }
310
311 public static class DeleteFromTableSqlCallPredicate extends MatchesRegexSqlCallPredicate {
312
313 private String tableName;
314
315 public DeleteFromTableSqlCallPredicate(String tableName) {
316 super("^delete//s+from//s+//b" + tableName+"//b") ;
317 this.tableName = tableName;
318 }
319
320 String getFailureMessage() {
321 return "delete from "+tableName;
322 }
323
324 }
325 }