package com.taobao.dbunit.dao;
import java.sql.SQLException;
import javax.sql.DataSource;
import org.dbunit.Assertion;
import org.dbunit.database.DatabaseConnection;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.DefaultDataSet;
import org.dbunit.dataset.DefaultTable;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSet;
import org.dbunit.operation.DatabaseOperation;
import org.junit.Assert;
import org.junit.Before;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.jdbc.datasource.DataSourceUtils;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
import org.springframework.test.context.transaction.TransactionConfiguration;
@ContextConfiguration(locations = { "classpath:testApplicationContext.xml" })
@TransactionConfiguration(defaultRollback = true)
public class BaseDaoTest extends AbstractTransactionalJUnit4SpringContextTests {
@Autowired
private DataSource dataSource;
private IDatabaseConnection conn;
@Before
public void initDbunit() throws Exception {
conn = new DatabaseConnection(DataSourceUtils.getConnection(dataSource));
}
/**
* 清空file中包含的表中的数据,并插入file中指定的数据
*
* @param file
* @throws Exception
*/
protected void setUpDataSet(String file) throws Exception {
IDataSet dataset = new FlatXmlDataSet(new ClassPathResource(file)
.getFile());
DatabaseOperation.CLEAN_INSERT.execute(conn, dataset);
}
/**
* 验证file中包含的表中的数据和数据库中的相应表的数据是否一致
*
* @param file
* @throws Exception
*/
protected void verifyDataSet(String file) throws Exception {
IDataSet expected = new FlatXmlDataSet(new ClassPathResource(file)
.getFile());
IDataSet dataset = conn.createDataSet();
for (String tableName : expected.getTableNames()) {
Assertion.assertEquals(expected.getTable(tableName), dataset
.getTable(tableName));
}
}
/**
* 清空指定的表中的数据
*
* @param tableName
* @throws Exception
*/
protected void clearTable(String tableName) throws Exception {
DefaultDataSet dataset = new DefaultDataSet();
dataset.addTable(new DefaultTable(tableName));
DatabaseOperation.DELETE_ALL.execute(conn, dataset);
}
/**
* 验证指定的表为空
*
* @param tableName
* @throws DataSetException
* @throws SQLException
*/
protected void verifyEmpty(String tableName) throws DataSetException,
SQLException {
Assert.assertEquals(0, conn.createDataSet().getTable(tableName)
.getRowCount());
}
}