/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapreduce.lib.db;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;

/**
 * An InputFormat that reads input data from an SQL table.
 * <p>
 * DBInputFormat emits LongWritables containing the record number as
 * key and DBWritables as value.
 *
 * The SQL query and the input class can be specified using one of the
 * two setInput methods.
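 * <p>
 * A minimal usage sketch (the driver class, JDBC URL, table, field names,
 * and the {@code MyRecord} type below are illustrative assumptions, not
 * part of this API):
 * <pre>{@code
 * Job job = Job.getInstance(new Configuration());
 * DBConfiguration.configureDB(job.getConfiguration(),
 *     "com.mysql.jdbc.Driver", "jdbc:mysql://localhost/mydb");
 * DBInputFormat.setInput(job, MyRecord.class,
 *     "employees", null, "id", "id", "name", "salary");
 * }</pre>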
 */
@InterfaceAudience.Public
@InterfaceStability.Stable
public class DBInputFormat<T extends DBWritable>
    extends InputFormat<LongWritable, T> implements Configurable {

  private String dbProductName = "DEFAULT";

  /**
   * A {@link DBWritable} and {@link Writable} implementation whose read
   * and write methods do nothing.
   */
  @InterfaceStability.Evolving
  public static class NullDBWritable implements DBWritable, Writable {
    @Override
    public void readFields(DataInput in) throws IOException { }
    @Override
    public void readFields(ResultSet arg0) throws SQLException { }
    @Override
    public void write(DataOutput out) throws IOException { }
    @Override
    public void write(PreparedStatement arg0) throws SQLException { }
  }

  /**
   * An InputSplit that spans a contiguous range of rows.
   */
  @InterfaceStability.Evolving
  public static class DBInputSplit extends InputSplit implements Writable {

    private long end = 0;
    private long start = 0;

    /**
     * Default Constructor.
     */
    public DBInputSplit() {
    }

    /**
     * Convenience Constructor.
     * @param start the index of the first row to select
     * @param end the index of the row after the last row to select (exclusive)
     */
    public DBInputSplit(long start, long end) {
      this.start = start;
      this.end = end;
    }

    /** {@inheritDoc} */
    public String[] getLocations() throws IOException {
      // TODO Add a layer to enable SQL "sharding" and support locality
      return new String[] {};
    }

    /**
     * @return The index of the first row to select
     */
    public long getStart() {
      return start;
    }

    /**
     * @return The end index of this split (exclusive)
     */
    public long getEnd() {
      return end;
    }

    /**
     * @return The total row count in this split
     */
    public long getLength() throws IOException {
      return end - start;
    }

    /** {@inheritDoc} */
    public void readFields(DataInput input) throws IOException {
      start = input.readLong();
      end = input.readLong();
    }

    /** {@inheritDoc} */
    public void write(DataOutput output) throws IOException {
      output.writeLong(start);
      output.writeLong(end);
    }
  }

  private String conditions;

  private Connection connection;

  private String tableName;

  private String[] fieldNames;

  private DBConfiguration dbConf;

  /** {@inheritDoc} */
  public void setConf(Configuration conf) {

    dbConf = new DBConfiguration(conf);

    try {
      getConnection();

      DatabaseMetaData dbMeta = connection.getMetaData();
      this.dbProductName = dbMeta.getDatabaseProductName().toUpperCase();
    } catch (Exception ex) {
      throw new RuntimeException(ex);
    }

    tableName = dbConf.getInputTableName();
    fieldNames = dbConf.getInputFieldNames();
    conditions = dbConf.getInputConditions();
  }

  public Configuration getConf() {
    return dbConf.getConf();
  }

  public DBConfiguration getDBConf() {
    return dbConf;
  }

  public Connection getConnection() {
    try {
      if (null == this.connection) {
        // The connection was closed or never opened; (re)instantiate it.
        this.connection = dbConf.getConnection();
        this.connection.setAutoCommit(false);
        this.connection.setTransactionIsolation(
            Connection.TRANSACTION_SERIALIZABLE);
      }
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return connection;
  }

  public String getDBProductName() {
    return dbProductName;
  }

  protected RecordReader<LongWritable, T> createDBRecordReader(
      DBInputSplit split, Configuration conf) throws IOException {

    @SuppressWarnings("unchecked")
    Class<T> inputClass = (Class<T>) (dbConf.getInputClass());
    try {
      // Use the database product name to determine the appropriate record
      // reader.
      if (dbProductName.startsWith("ORACLE")) {
        // Use the Oracle-specific db reader.
        return new OracleDBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      } else if (dbProductName.startsWith("MYSQL")) {
        // Use the MySQL-specific db reader.
        return new MySQLDBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      } else {
        // Generic reader.
        return new DBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      }
    } catch (SQLException ex) {
      // Preserve the underlying SQLException as the cause.
      throw new IOException(ex.getMessage(), ex);
    }
  }

  /** {@inheritDoc} */
  @SuppressWarnings("unchecked")
  public RecordReader<LongWritable, T> createRecordReader(InputSplit split,
      TaskAttemptContext context) throws IOException, InterruptedException {

    return createDBRecordReader((DBInputSplit) split,
        context.getConfiguration());
  }

  /** {@inheritDoc} */
  public List<InputSplit> getSplits(JobContext job) throws IOException {

    ResultSet results = null;
    Statement statement = null;
    try {
      // Use getConnection() so a previously closed connection is recreated.
      statement = getConnection().createStatement();

      results = statement.executeQuery(getCountQuery());
      results.next();

      long count = results.getLong(1);
      int chunks = job.getConfiguration().getInt(MRJobConfig.NUM_MAPS, 1);
      long chunkSize = (count / chunks);

      results.close();
      statement.close();

      List<InputSplit> splits = new ArrayList<InputSplit>();

      // Split the rows into n-number of chunks and adjust the last chunk
      // accordingly.
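      // For example (illustrative numbers), count = 10 and chunks = 3 give
      // chunkSize = 3 and splits [0, 3), [3, 6), and [6, 10).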
      for (int i = 0; i < chunks; i++) {
        DBInputSplit split;

        if ((i + 1) == chunks) {
          split = new DBInputSplit(i * chunkSize, count);
        } else {
          split = new DBInputSplit(i * chunkSize, (i * chunkSize)
              + chunkSize);
        }

        splits.add(split);
      }

      connection.commit();
      return splits;
    } catch (SQLException e) {
      throw new IOException("Got SQLException", e);
    } finally {
      try {
        if (results != null) { results.close(); }
      } catch (SQLException e1) {
        // Ignore exception on close.
      }
      try {
        if (statement != null) { statement.close(); }
      } catch (SQLException e1) {
        // Ignore exception on close.
      }

      closeConnection();
    }
  }


  /**
   * Returns the query for getting the total number of rows;
   * subclasses can override this for custom behaviour.
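   *
   * For example, with {@code tableName} "employees" and {@code conditions}
   * "salary > 0" (illustrative values), the default implementation returns
   * {@code SELECT COUNT(*) FROM employees WHERE salary > 0}.
   */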
  protected String getCountQuery() {

    if (dbConf.getInputCountQuery() != null) {
      return dbConf.getInputCountQuery();
    }

    StringBuilder query = new StringBuilder();
    query.append("SELECT COUNT(*) FROM ").append(tableName);

    if (conditions != null && conditions.length() > 0) {
      query.append(" WHERE ").append(conditions);
    }
    return query.toString();
  }

  /**
   * Initializes the map-part of the job with the appropriate input settings.
   *
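   * <p>
   * A minimal sketch (the {@code MyRecord} class, table, and field names
   * below are illustrative assumptions):
   * <pre>{@code
   * DBInputFormat.setInput(job, MyRecord.class,
   *     "employees", "salary > 0", "id", "id", "name", "salary");
   * }</pre>
   *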
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the
   * Java object holding tuple fields.
   * @param tableName The table to read data from
   * @param conditions The conditions with which to select data,
   * e.g. '(updated > 20070101 AND length > 0)'
   * @param orderBy the fieldNames in the orderBy clause.
   * @param fieldNames The field names in the table
   * @see #setInput(Job, Class, String, String)
   */
  public static void setInput(Job job,
      Class<? extends DBWritable> inputClass,
      String tableName, String conditions,
      String orderBy, String... fieldNames) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputTableName(tableName);
    dbConf.setInputFieldNames(fieldNames);
    dbConf.setInputConditions(conditions);
    dbConf.setInputOrderBy(orderBy);
  }

  /**
   * Initializes the map-part of the job with the appropriate input settings.
   *
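   * <p>
   * A minimal sketch using the example queries below (the {@code MyRecord}
   * class is an illustrative assumption):
   * <pre>{@code
   * DBInputFormat.setInput(job, MyRecord.class,
   *     "SELECT f1, f2, f3 FROM Mytable ORDER BY f1",
   *     "SELECT COUNT(f1) FROM Mytable");
   * }</pre>
   *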
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the
   * Java object holding tuple fields.
   * @param inputQuery the input query to select fields. Example:
   * "SELECT f1, f2, f3 FROM Mytable ORDER BY f1"
   * @param inputCountQuery the input query that returns
   * the number of records in the table.
   * Example: "SELECT COUNT(f1) FROM Mytable"
   * @see #setInput(Job, Class, String, String, String, String...)
   */
  public static void setInput(Job job,
      Class<? extends DBWritable> inputClass,
      String inputQuery, String inputCountQuery) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputQuery(inputQuery);
    dbConf.setInputCountQuery(inputCountQuery);
  }

  protected void closeConnection() {
    try {
      if (null != this.connection) {
        this.connection.close();
        this.connection = null;
      }
    } catch (SQLException sqlE) {
      // Ignore exception on close.
    }
  }
}