Created July 18, 2025 13:20
Save chrisui/b700330efa63e55a1cab197fc38cf5f7 to your computer and use it in GitHub Desktop.
Better-auth Drizzle adapter that uses schema IDs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
/**
 * This module is a copy and extension of the better-auth Drizzle adapter that
 * allows ID generation to be handled by the schema and/or the database.
 *
 * @see https://github.com/better-auth/better-auth/blob/main/packages/better-auth/src/adapters/drizzle-adapter/drizzle-adapter.ts
 * @see https://github.com/better-auth/better-auth/issues/3445
 * @see https://github.com/better-auth/better-auth/issues/2480
 */
/** biome-ignore-all assist/source/organizeImports: tmp, copied from 3rd-party */ | |
/** biome-ignore-all lint/style/useImportType: tmp, copied from 3rd-party */ | |
/** biome-ignore-all lint/style/noUselessElse: tmp, copied from 3rd-party */ | |
/** biome-ignore-all lint/correctness/noUnusedFunctionParameters: tmp, copied from 3rd-party */ | |
/** biome-ignore-all lint/suspicious/noExplicitAny: tmp, copied from 3rd-party */ | |
/** biome-ignore-all lint/style/noNonNullAssertion: tmp, copied from 3rd-party */ | |
/** biome-ignore-all lint/complexity/useOptionalChain: tmp, copied from 3rd-party */ | |
import { | |
and, | |
asc, | |
count, | |
desc, | |
eq, | |
gt, | |
gte, | |
inArray, | |
like, | |
lt, | |
lte, | |
ne, | |
or, | |
sql, | |
SQL, | |
} from 'drizzle-orm'; | |
import { BetterAuthError } from 'better-auth'; | |
import type { Where } from 'better-auth/types'; | |
import { createAdapter, type AdapterDebugLogs } from 'better-auth/adapters'; | |
/**
 * Structural stand-in for the Drizzle database instance passed to the adapter.
 *
 * Intentionally open-ended: the adapter reaches into the instance dynamically
 * (query builders such as `select`/`insert`/`update`/`delete`, and the
 * internal `_.fullSchema` metadata), so no stricter shape is declared here.
 */
export interface DB {
  [property: string]: any;
}
/**
 * Options accepted by the Drizzle adapter factory.
 */
export interface DrizzleAdapterConfig {
  /**
   * SQL dialect of the Drizzle instance. For `mysql`, which has no
   * `RETURNING`, written rows are read back with a follow-up `SELECT`.
   */
  provider: 'pg' | 'mysql' | 'sqlite';
  /**
   * Drizzle table objects keyed by model name. When omitted, the adapter
   * falls back to the database instance's `db._.fullSchema`.
   */
  schema?: Record<string, any>;
  /**
   * Set to `true` when the schema's table keys are pluralised
   * (e.g. `users` instead of `user`).
   *
   * @default false
   */
  usePlural?: boolean;
  /**
   * Debug-logging configuration forwarded to better-auth's adapter factory.
   *
   * @default false
   */
  debugLogs?: AdapterDebugLogs;
  /**
   * Whether the `id` supplied by better-auth is written when a record is
   * created. Set to `false` to drop it and let schema-level generation
   * (e.g. `.$defaultFn(() => 'randomid')`) or the database assign the id.
   *
   * @default true
   */
  generateIds?: boolean;
}
/**
 * Creates a better-auth database adapter backed by a Drizzle ORM instance.
 *
 * Mirrors the upstream better-auth Drizzle adapter, extended so that
 * `config.generateIds === false` drops the `id` better-auth supplies on
 * `create`, letting the schema (e.g. `.$defaultFn(...)`) and/or the database
 * generate it instead.
 *
 * @param db - Drizzle database instance. `db._.fullSchema` is used as the
 *   schema when `config.schema` is not provided.
 * @param config - Adapter options.
 * @returns The adapter produced by better-auth's `createAdapter`.
 * @throws BetterAuthError if no schema is available, a model is missing from
 *   it, or a query/insert references a field the model does not define.
 */
export const drizzleAdapter = (db: DB, config: DrizzleAdapterConfig) =>
  createAdapter({
    config: {
      adapterId: 'drizzle',
      adapterName: 'Drizzle Adapter',
      usePlural: config.usePlural ?? false,
      debugLogs: config.debugLogs ?? false,
    },
    adapter: ({ getFieldName }) => {
      /** Returns the Drizzle table object registered for `model`, or throws. */
      function getSchema(model: string) {
        // `db._` is Drizzle's internal metadata; guard it so a missing schema
        // surfaces as the descriptive BetterAuthError below, not a TypeError.
        const schema = config.schema || db._?.fullSchema;
        if (!schema) {
          throw new BetterAuthError(
            'Drizzle adapter failed to initialize. Schema not found. Please provide a schema object in the adapter options object.',
          );
        }
        const schemaModel = schema[model];
        if (!schemaModel) {
          throw new BetterAuthError(
            `[# Drizzle Adapter]: The model "${model}" was not found in the schema object. Please pass the schema directly to the adapter options.`,
          );
        }
        return schemaModel;
      }

      /**
       * Runs an insert/update builder and returns the affected row.
       *
       * Postgres and SQLite support `RETURNING`, so the row comes straight
       * back. MySQL does not, so the row is re-read afterwards: by the `where`
       * clause (updates), by a known id (inserts), or — as a last resort — by
       * taking the row with the greatest `id`.
       */
      const withReturning = async (
        model: string,
        builder: any,
        data: Record<string, any>,
        where?: Where[],
      ) => {
        if (config.provider !== 'mysql') {
          const rows = await builder.returning();
          return rows[0];
        }
        await builder.execute();
        const schemaModel = getSchema(model);
        if (where?.length) {
          // NOTE(review): this re-reads with the pre-update filter; if the
          // update changed one of the filtered columns the row will not be
          // found. Confirm whether callers can hit this.
          const clause = convertWhereClause(where, model);
          const rows = await db
            .select()
            .from(schemaModel)
            .where(...clause)
            .limit(1);
          return rows[0];
        }
        // Prefer the id actually sent to the database (read from the insert
        // builder's internal `config.values`, whose entries expose `.value`),
        // then the id present in the caller's data.
        const knownId = builder.config?.values?.[0]?.id?.value || data.id;
        if (knownId) {
          const rows = await db
            .select()
            .from(schemaModel)
            .where(eq(schemaModel.id, knownId))
            .limit(1)
            .execute();
          return rows[0];
        }
        // Last resort, only correct for monotonically increasing ids: take
        // the newest row. Every model is expected to define an `id` column.
        if (!('id' in schemaModel)) {
          throw new BetterAuthError(
            `The model "${model}" does not have an "id" field. Please use the "id" field as your primary key.`,
          );
        }
        const rows = await db
          .select()
          .from(schemaModel)
          .orderBy(desc(schemaModel.id))
          .limit(1)
          .execute();
        return rows[0];
      };

      /**
       * Translates one better-auth `Where` entry into a Drizzle condition.
       * All operators are handled here so that single- and multi-entry
       * `where` arrays behave identically.
       *
       * @throws BetterAuthError if the field is not in the model's schema, or
       *   an `in` operator is given a non-array value.
       */
      function buildCondition(schemaModel: any, model: string, w: Where): SQL {
        const field = getFieldName({ model, field: w.field });
        const column = schemaModel[field];
        if (!column) {
          throw new BetterAuthError(
            `The field "${w.field}" does not exist in the schema for the model "${model}". Please update your schema.`,
          );
        }
        switch (w.operator) {
          case 'in':
            if (!Array.isArray(w.value)) {
              throw new BetterAuthError(
                `The value for the field "${w.field}" must be an array when using the "in" operator.`,
              );
            }
            return inArray(column, w.value);
          case 'contains':
            return like(column, `%${w.value}%`);
          case 'starts_with':
            return like(column, `${w.value}%`);
          case 'ends_with':
            return like(column, `%${w.value}`);
          case 'lt':
            return lt(column, w.value);
          case 'lte':
            return lte(column, w.value);
          case 'ne':
            return ne(column, w.value);
          case 'gt':
            return gt(column, w.value);
          case 'gte':
            return gte(column, w.value);
          default:
            return eq(column, w.value);
        }
      }

      /**
       * Converts better-auth `Where` entries into at most ONE Drizzle
       * condition, returned as an array so call sites can spread it into
       * `.where(...)` (an empty array means "no filter").
       *
       * Entries with connector `AND` (or none) are AND-ed together, entries
       * with `OR` are OR-ed together, and the two groups are then AND-ed.
       * Drizzle's `.where()` accepts a single condition, so returning the
       * groups as separate elements would silently drop the second one.
       */
      function convertWhereClause(where: Where[], model: string): SQL[] {
        const schemaModel = getSchema(model);
        if (!where?.length) return [];
        if (where.length === 1) {
          const [only] = where;
          return only ? [buildCondition(schemaModel, model, only)] : [];
        }
        const toCondition = (w: Where) => buildCondition(schemaModel, model, w);
        const andConditions = where
          .filter((w) => w.connector === 'AND' || !w.connector)
          .map(toCondition);
        const orConditions = where
          .filter((w) => w.connector === 'OR')
          .map(toCondition);
        const groups: SQL[] = [];
        if (andConditions.length) groups.push(and(...andConditions)!);
        if (orConditions.length) groups.push(or(...orConditions)!);
        return groups.length > 1 ? [and(...groups)!] : groups;
      }

      /**
       * Throws if `values` contains a key that is not a column of the model's
       * Drizzle table object.
       */
      function checkMissingFields(
        schema: Record<string, any>,
        model: string,
        values: Record<string, any>,
      ) {
        if (!schema) {
          throw new BetterAuthError(
            'Drizzle adapter failed to initialize. Schema not found. Please provide a schema object in the adapter options object.',
          );
        }
        for (const key in values) {
          if (!schema[key]) {
            throw new BetterAuthError(
              `The field "${key}" does not exist in the "${model}" schema. Please update your drizzle schema or re-generate using "npx @better-auth/cli generate".`,
            );
          }
        }
      }

      return {
        async create({ model, data: values }) {
          const schemaModel = getSchema(model);
          // With generateIds === false, drop better-auth's id so the schema's
          // default function and/or the database assigns it.
          const finalValues: Record<string, any> =
            config.generateIds === false && 'id' in schemaModel
              ? Object.fromEntries(
                  Object.entries(values).filter(([key]) => key !== 'id'),
                )
              : { ...values };
          // MySQL has no RETURNING, so the new row must be re-read by id. When
          // the id is left to the schema, invoke the column's `$defaultFn`
          // here (the same function Drizzle would call on insert) so the id
          // is known. SQL-valued or database-side defaults cannot be captured
          // this way and fall back to withReturning's "newest row" lookup.
          if (config.provider === 'mysql' && !('id' in finalValues)) {
            const generated = schemaModel.id?.defaultFn?.();
            if (generated != null && typeof generated !== 'object') {
              finalValues.id = generated;
            }
          }
          checkMissingFields(schemaModel, model, finalValues);
          const builder = db.insert(schemaModel).values(finalValues);
          return await withReturning(model, builder, finalValues);
        },
        async findOne({ model, where }) {
          const schemaModel = getSchema(model);
          const clause = convertWhereClause(where, model);
          const rows = await db
            .select()
            .from(schemaModel)
            .where(...clause)
            .limit(1);
          return rows[0] ?? null;
        },
        async findMany({ model, where, sortBy, limit, offset }) {
          const schemaModel = getSchema(model);
          const clause = where ? convertWhereClause(where, model) : [];
          const sortFn = sortBy?.direction === 'desc' ? desc : asc;
          const builder = db
            .select()
            .from(schemaModel)
            .limit(limit || 100)
            .offset(offset || 0);
          if (sortBy?.field) {
            // Drizzle builders mutate in place, so the result need not be
            // reassigned.
            builder.orderBy(
              sortFn(schemaModel[getFieldName({ model, field: sortBy.field })]),
            );
          }
          return (await builder.where(...clause)) as any[];
        },
        async count({ model, where }) {
          const schemaModel = getSchema(model);
          const clause = where ? convertWhereClause(where, model) : [];
          const res = await db
            .select({ count: count() })
            .from(schemaModel)
            .where(...clause);
          return res[0].count;
        },
        async update({ model, where, update: values }) {
          const schemaModel = getSchema(model);
          const clause = convertWhereClause(where, model);
          const builder = db
            .update(schemaModel)
            .set(values)
            .where(...clause);
          return await withReturning(model, builder, values as any, where);
        },
        async updateMany({ model, where, update: values }) {
          const schemaModel = getSchema(model);
          const clause = convertWhereClause(where, model);
          const builder = db
            .update(schemaModel)
            .set(values)
            .where(...clause);
          return await builder;
        },
        async delete({ model, where }) {
          const schemaModel = getSchema(model);
          const clause = convertWhereClause(where, model);
          const builder = db.delete(schemaModel).where(...clause);
          return await builder;
        },
        async deleteMany({ model, where }) {
          const schemaModel = getSchema(model);
          const clause = convertWhereClause(where, model);
          const builder = db.delete(schemaModel).where(...clause);
          return await builder;
        },
        options: config,
      };
    },
  });
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.