1 package com.imcode.db.mock;
2
3 import com.imcode.db.Database;
4 import com.imcode.db.DatabaseCommand;
5 import com.imcode.db.DatabaseException;
6 import junit.framework.Assert;
7 import org.apache.commons.collections.CollectionUtils;
8 import org.apache.commons.collections.Predicate;
9 import org.apache.commons.dbutils.ResultSetHandler;
10 import org.apache.commons.lang.ArrayUtils;
11 import org.apache.commons.lang.StringUtils;
12
13 import java.sql.ResultSet;
14 import java.sql.SQLException;
15 import java.util.ArrayList;
16 import java.util.Arrays;
17 import java.util.Iterator;
18 import java.util.List;
19 import java.util.Map;
20 import java.util.regex.Matcher;
21 import java.util.regex.Pattern;
22
23 public class MockDatabase implements Database {
24
25 private List sqlCalls = new ArrayList();
26 private List expectedSqlCalls = new ArrayList();
27
28 public int executeUpdate(String sqlStr, Object[] parameters) {
29 getResultForSqlCall(sqlStr, parameters);
30 return 0;
31 }
32
33
34 public Object executeQuery(String sqlQuery, Object[] parameters, ResultSetHandler resultSetHandler) {
35 ResultSet resultSet = (ResultSet) getResultForSqlCall(sqlQuery, parameters);
36 if (null == resultSet ) {
37 resultSet = new MockResultSet(new Object[0][]) ;
38 }
39 try {
40 return resultSetHandler.handle(resultSet) ;
41 } catch ( SQLException e ) {
42 throw DatabaseException.fromSQLException("", e);
43 }
44 }
45
46 public Object execute(DatabaseCommand databaseCommand) throws DatabaseException {
47 return databaseCommand.executeOn(new MockDatabaseConnection(this));
48 }
49
50 public Object executeCommand(DatabaseCommand databaseCommand) throws DatabaseException {
51 return execute(databaseCommand);
52 }
53
54 public void addExpectedSqlCall(final SqlCallPredicate sqlCallPredicate, final Object result) {
55 expectedSqlCalls.add(new Map.Entry() {
56 public Object getKey() {
57 return sqlCallPredicate;
58 }
59
60 public Object getValue() {
61 return result;
62 }
63
64 public Object setValue(Object value) {
65 throw new UnsupportedOperationException();
66 }
67
68 public String toString() {
69 return sqlCallPredicate + ": " + result;
70 }
71 });
72 }
73
74 public void assertExpectedSqlCalls() {
75 if (!expectedSqlCalls.isEmpty()) {
76 Assert.fail("Remaining expected sql calls: " + expectedSqlCalls.toString());
77 }
78 }
79
80 public int getSqlCallCount() {
81 return sqlCalls.size();
82 }
83
84 Object getResultForSqlCall(String sql, Object[] params) {
85 SqlCall sqlCall = new SqlCall(sql, params);
86 sqlCalls.add(sqlCall);
87 Object result = null;
88 if (!expectedSqlCalls.isEmpty()) {
89 Map.Entry entry = (Map.Entry) expectedSqlCalls.get(0);
90 SqlCallPredicate predicate = (SqlCallPredicate) entry.getKey();
91 if (predicate.evaluateSqlCall(sqlCall)) {
92 result = entry.getValue();
93 expectedSqlCalls.remove(0);
94 }
95 }
96 return result;
97 }
98
99 public static class SqlCall {
100
101 private String string;
102 private Object[] parameters;
103
104 public SqlCall(String string, Object[] parameters) {
105 this.string = string;
106 this.parameters = parameters;
107 }
108
109 public String getString() {
110 return string;
111 }
112
113 public Object[] getParameters() {
114 return parameters;
115 }
116
117 public String toString() {
118 return getString() + " " + StringUtils.join(getParameters(), ", ");
119 }
120
121 }
122
123 public void assertCalled(SqlCallPredicate predicate) {
124 assertCalled(null, predicate);
125 }
126
127 public void assertCalledInOrder(SqlCallPredicate[] sqlCallPredicates) {
128 int sqlCallPredicatesIndex = 0 ;
129 for ( Iterator iterator = sqlCalls.iterator(); iterator.hasNext(); ) {
130 SqlCall sqlCall = (SqlCall) iterator.next();
131 if (sqlCallPredicates[sqlCallPredicatesIndex].evaluateSqlCall(sqlCall)) {
132 sqlCallPredicatesIndex++ ;
133 if (sqlCallPredicatesIndex == sqlCallPredicates.length) {
134 break ;
135 }
136 }
137 }
138 if (sqlCallPredicatesIndex < sqlCallPredicates.length) {
139 String failureMessage = "Expected sql call \"" + sqlCallPredicates[sqlCallPredicatesIndex].getFailureMessage()+"\"";
140 if (sqlCallPredicatesIndex > 0) {
141 failureMessage += " after sql call \""+sqlCallPredicates[sqlCallPredicatesIndex-1]+"\"" ;
142 }
143 Assert.fail(failureMessage) ;
144 }
145 }
146
147 public void assertCalled(String message, SqlCallPredicate predicate) {
148 if (!called(predicate)) {
149 String messagePrefix = null == message ? "" : message + " ";
150 Assert.fail(messagePrefix + "Expected at least one sql call: " + predicate.getFailureMessage());
151 }
152 }
153
154 private boolean called(SqlCallPredicate predicate) {
155 return CollectionUtils.exists(sqlCalls, predicate);
156 }
157
158 public void assertNotCalled(SqlCallPredicate sqlCallPredicate) {
159 assertNotCalled(null, sqlCallPredicate);
160 }
161
162 public void assertNotCalled(String message, SqlCallPredicate predicate) {
163 if (called(predicate)) {
164 String messagePrefix = null == message ? "" : message + " ";
165 Assert.fail(messagePrefix + "Got unexpected sql call: " + predicate.getFailureMessage());
166 }
167 }
168
169 public void assertCallCount(int expectedCount, SqlCallPredicate predicate) {
170 int actualCount = CollectionUtils.countMatches(sqlCalls, predicate);
171 if (expectedCount != actualCount) {
172 Assert.fail("Expected " + expectedCount + ", but got " + actualCount + " sql calls: " + predicate.getFailureMessage());
173 }
174 }
175
176 public abstract static class SqlCallPredicate implements Predicate {
177
178 public final boolean evaluate(Object object) {
179 return evaluateSqlCall((MockDatabase.SqlCall) object);
180 }
181
182 abstract boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall);
183
184 abstract String getFailureMessage();
185
186 public String toString() {
187 return getFailureMessage();
188 }
189 }
190
191 public static class UpdateTableSqlCallPredicate extends SqlCallPredicate {
192
193 private String tableName;
194 private Object parameter;
195
196 public UpdateTableSqlCallPredicate(String tableName, Object parameter) {
197 this.tableName = tableName;
198 this.parameter = parameter;
199 }
200
201 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
202 boolean stringMatchesUpdateTableName = Pattern.compile("^update//s+//b" + tableName+"//b").matcher(sqlCall.getString().toLowerCase()).find();
203 boolean parametersContainsParameter = ArrayUtils.contains(sqlCall.getParameters(), parameter);
204 return stringMatchesUpdateTableName && parametersContainsParameter;
205 }
206
207 String getFailureMessage() {
208 return "update of table " + tableName + " with one parameter = " + parameter;
209 }
210 }
211
212 public static class InsertIntoTableSqlCallPredicate extends MatchesRegexSqlCallPredicate {
213
214 private String tableName;
215
216 public InsertIntoTableSqlCallPredicate(String tableName) {
217 super("^insert//s+(?:into//s+)?//b" + tableName+"//b") ;
218 this.tableName = tableName;
219 }
220
221 String getFailureMessage() {
222 return "insert into table " + tableName ;
223 }
224 }
225
226 public static class InsertIntoTableWithParameterSqlCallPredicate extends InsertIntoTableSqlCallPredicate {
227
228 private String parameter;
229
230 public InsertIntoTableWithParameterSqlCallPredicate(String tableName, String parameter) {
231 super(tableName);
232 this.parameter = parameter;
233 }
234
235 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
236 return super.evaluateSqlCall(sqlCall) && ArrayUtils.contains(sqlCall.getParameters(), parameter);
237 }
238
239 String getFailureMessage() {
240 return super.getFailureMessage() + " with one parameter = \"" + parameter + "\"";
241 }
242 }
243
244 public static class MatchesRegexSqlCallPredicate extends SqlCallPredicate {
245
246 private String regex;
247
248 public MatchesRegexSqlCallPredicate(String regex) {
249 this.regex = regex;
250 }
251
252 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
253 Pattern pattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE);
254 Matcher matcher = pattern.matcher(sqlCall.getString());
255 return matcher.find();
256 }
257
258 String getFailureMessage() {
259 return "Expected call to match regex " + regex;
260 }
261 }
262
263 public static class EqualsSqlCallPredicate extends SqlCallPredicate {
264
265 String sql;
266
267 public EqualsSqlCallPredicate(String sql) {
268 this.sql = sql;
269 }
270
271 boolean evaluateSqlCall(SqlCall sqlCall) {
272 return sql.equalsIgnoreCase(sqlCall.getString());
273 }
274
275 String getFailureMessage() {
276 return "sql \"" + sql + "\"";
277 }
278 }
279
280 public static class StartsWithSqlCallPredicate extends SqlCallPredicate {
281
282 private String prefix;
283
284 public StartsWithSqlCallPredicate(String prefix) {
285 this.prefix = prefix;
286 }
287
288 boolean evaluateSqlCall(SqlCall sqlCall) {
289 return sqlCall.getString().startsWith(prefix);
290 }
291
292 String getFailureMessage() {
293 return "start with " + prefix;
294 }
295 }
296
297 public static class EqualsWithParametersSqlCallPredicate extends EqualsSqlCallPredicate {
298
299 private String[] parameters;
300
301 public EqualsWithParametersSqlCallPredicate(String sql, String[] parameters) {
302 super(sql);
303 this.parameters = parameters;
304 }
305
306 boolean evaluateSqlCall(SqlCall sqlCall) {
307 return super.evaluateSqlCall(sqlCall) && Arrays.equals(parameters, sqlCall.getParameters());
308 }
309
310 String getFailureMessage() {
311 return super.getFailureMessage() + " with parameters " + ArrayUtils.toString(parameters);
312 }
313 }
314
315 public static class DeleteFromTableSqlCallPredicate extends MatchesRegexSqlCallPredicate {
316
317 private String tableName;
318
319 public DeleteFromTableSqlCallPredicate(String tableName) {
320 super("^delete//s+from//s+//b" + tableName+"//b") ;
321 this.tableName = tableName;
322 }
323
324 String getFailureMessage() {
325 return "delete from "+tableName;
326 }
327
328 }
329 }