Skip to content

Instantly share code, notes, and snippets.

@chrisui
Created July 18, 2025 13:20
Show Gist options
  • Save chrisui/b700330efa63e55a1cab197fc38cf5f7 to your computer and use it in GitHub Desktop.
Save chrisui/b700330efa63e55a1cab197fc38cf5f7 to your computer and use it in GitHub Desktop.
Better Auth Drizzle adapter that uses schema-generated IDs
/**
* This module is a copy and extension of the better-auth drizzle-adapter that
* lets ID generation be handled by the Drizzle schema and/or the database.
*
* @see https://github.com/better-auth/better-auth/blob/main/packages/better-auth/src/adapters/drizzle-adapter/drizzle-adapter.ts
* @see https://github.com/better-auth/better-auth/issues/3445
* @see https://github.com/better-auth/better-auth/issues/2480
*/
/** biome-ignore-all assist/source/organizeImports: tmp, copied from 3rd-party */
/** biome-ignore-all lint/style/useImportType: tmp, copied from 3rd-party */
/** biome-ignore-all lint/style/noUselessElse: tmp, copied from 3rd-party */
/** biome-ignore-all lint/correctness/noUnusedFunctionParameters: tmp, copied from 3rd-party */
/** biome-ignore-all lint/suspicious/noExplicitAny: tmp, copied from 3rd-party */
/** biome-ignore-all lint/style/noNonNullAssertion: tmp, copied from 3rd-party */
/** biome-ignore-all lint/complexity/useOptionalChain: tmp, copied from 3rd-party */
import {
and,
asc,
count,
desc,
eq,
gt,
gte,
inArray,
like,
lt,
lte,
ne,
or,
sql,
SQL,
} from 'drizzle-orm';
import { BetterAuthError } from 'better-auth';
import type { Where } from 'better-auth/types';
import { createAdapter, type AdapterDebugLogs } from 'better-auth/adapters';
/**
 * Loose structural type for the Drizzle database instance passed to
 * `drizzleAdapter`. Every property is typed `any` on purpose, so that any
 * Drizzle driver instance can be supplied without tying this module to one.
 */
export interface DB {
	[property: string]: any;
}
/**
 * Options accepted by `drizzleAdapter`.
 */
export interface DrizzleAdapterConfig {
	/**
	 * Drizzle schema object, keyed by model (table) name.
	 * If omitted, the adapter uses the schema registered on the db instance
	 * (`db._.fullSchema`).
	 */
	schema?: Record<string, any>;
	/**
	 * SQL dialect of the underlying database. `mysql` has no `RETURNING`
	 * support, so after a write the adapter re-reads the affected row.
	 */
	provider: 'pg' | 'mysql' | 'sqlite';
	/**
	 * Set to `true` when the schema keys are plural, e.g. `users` rather than
	 * `user`.
	 *
	 * @default false
	 */
	usePlural?: boolean;
	/**
	 * Controls the adapter's debug logging.
	 *
	 * @default false
	 */
	debugLogs?: AdapterDebugLogs;
	/**
	 * When `false`, any `id` in the insert data is stripped, provided the model
	 * has an `id` column. The id is then produced by the schema (for example
	 * `.$defaultFn(() => 'randomid')`) or by the database.
	 *
	 * @default true
	 */
	generateIds?: boolean;
}
/**
 * Creates a better-auth database adapter backed by a Drizzle ORM instance.
 *
 * Compared with the upstream adapter, this version can leave `id` generation
 * to the Drizzle schema or the database. When `config.generateIds === false`,
 * any `id` supplied by better-auth is removed from insert values, so a schema
 * default (e.g. `$defaultFn`) or a database default fills it in.
 *
 * @param db - Drizzle database instance. When `config.schema` is not given,
 *   its `_.fullSchema` is used.
 * @param config - Adapter options (provider, schema, pluralisation, id handling).
 * @returns The adapter built by better-auth's `createAdapter`.
 * @throws BetterAuthError when the schema, a model, or a referenced field is missing.
 */
export const drizzleAdapter = (db: DB, config: DrizzleAdapterConfig) =>
	createAdapter({
		config: {
			adapterId: 'drizzle',
			adapterName: 'Drizzle Adapter',
			usePlural: config.usePlural ?? false,
			debugLogs: config.debugLogs ?? false,
		},
		adapter: ({ getFieldName, debugLog }) => {
			/**
			 * Resolves the Drizzle table object for `model`.
			 * Prefers `config.schema` and falls back to the schema registered on `db`.
			 */
			function getSchema(model: string) {
				// Optional chaining: if `db` does not expose `_`, raise the descriptive
				// error below instead of a bare TypeError.
				const schema = config.schema || db._?.fullSchema;
				if (!schema) {
					throw new BetterAuthError(
						'Drizzle adapter failed to initialize. Schema not found. Please provide a schema object in the adapter options object.',
					);
				}
				const schemaModel = schema[model];
				if (!schemaModel) {
					throw new BetterAuthError(
						`[# Drizzle Adapter]: The model "${model}" was not found in the schema object. Please pass the schema directly to the adapter options.`,
					);
				}
				return schemaModel;
			}

			/**
			 * Executes a write `builder` (insert or update) and returns the affected row.
			 *
			 * Postgres and SQLite use `RETURNING`. MySQL has no `RETURNING`, so the row
			 * is re-read after the write, using the first available strategy:
			 *   1. the `where` clause, when one is given (updates);
			 *   2. the `id` carried in the builder's values or in `data`;
			 *   3. the primary key reported by Drizzle's `$returningId()`. This covers
			 *      ids generated by a schema `$defaultFn`, e.g. with `generateIds: false`;
			 *   4. as a last resort, the row with the highest `id`. This is only
			 *      correct for monotonically increasing ids.
			 */
			const withReturning = async (
				model: string,
				builder: any,
				data: Record<string, any>,
				where?: Where[],
			) => {
				if (config.provider !== 'mysql') {
					const rows = await builder.returning();
					return rows[0];
				}
				const schemaModel = getSchema(model);
				const findById = async (id: unknown) => {
					const rows = await db
						.select()
						.from(schemaModel)
						.where(eq(schemaModel.id, id))
						.limit(1)
						.execute();
					return rows[0];
				};

				if (where?.length) {
					await builder.execute();
					// NOTE(review): if the update changes a column referenced by `where`,
					// this re-read will not find the updated row.
					const clause = convertWhereClause(where, model);
					const rows = await db
						.select()
						.from(schemaModel)
						.where(...clause)
						.limit(1);
					return rows[0];
				}

				const builderId = builder.config?.values?.[0]?.id?.value;
				if (builderId) {
					await builder.execute();
					return findById(builderId);
				}
				if (data.id) {
					await builder.execute();
					return findById(data.id);
				}

				// No id is known up front (e.g. it was stripped because generateIds is
				// false). Ask Drizzle for the primary key it actually inserted.
				// `$returningId()` runs the insert itself, so `execute()` is only called
				// when that method is unavailable (older drizzle-orm, or a non-insert builder).
				if (typeof builder.$returningId === 'function') {
					const inserted = await builder.$returningId();
					const insertedId = inserted?.[0]?.id;
					if (insertedId !== undefined && insertedId !== null) {
						return findById(insertedId);
					}
				} else {
					await builder.execute();
				}

				// The fallback below needs an `id` column; we expect all models to
				// define one.
				if (!('id' in schemaModel)) {
					throw new BetterAuthError(
						`The model "${model}" does not have an "id" field. Please use the "id" field as your primary key.`,
					);
				}
				const rows = await db
					.select()
					.from(schemaModel)
					.orderBy(desc(schemaModel.id))
					.limit(1)
					.execute();
				return rows[0];
			};

			/**
			 * Translates a single better-auth `Where` entry into a Drizzle condition.
			 * Operators without a dedicated case (including none) fall back to `eq`.
			 *
			 * @throws BetterAuthError if the field is missing from the model's schema,
			 *   or if the `in` operator is given a non-array value.
			 */
			function buildCondition(
				w: Where,
				model: string,
				schemaModel: Record<string, any>,
			): SQL {
				const field = getFieldName({ model, field: w.field });
				const column = schemaModel[field];
				if (!column) {
					throw new BetterAuthError(
						`The field "${w.field}" does not exist in the schema for the model "${model}". Please update your schema.`,
					);
				}
				switch (w.operator) {
					case 'in':
						if (!Array.isArray(w.value)) {
							throw new BetterAuthError(
								`The value for the field "${w.field}" must be an array when using the "in" operator.`,
							);
						}
						return inArray(column, w.value);
					case 'contains':
						return like(column, `%${w.value}%`);
					case 'starts_with':
						return like(column, `${w.value}%`);
					case 'ends_with':
						return like(column, `%${w.value}`);
					case 'lt':
						return lt(column, w.value);
					case 'lte':
						return lte(column, w.value);
					case 'ne':
						return ne(column, w.value);
					case 'gt':
						return gt(column, w.value);
					case 'gte':
						return gte(column, w.value);
					default:
						return eq(column, w.value);
				}
			}

			/**
			 * Converts better-auth `Where` entries into an array of at most one
			 * Drizzle condition. Callers spread the result into `.where(...)`.
			 *
			 * Entries with connector `AND` (or no connector) are joined with AND.
			 * Entries with connector `OR` are joined with OR. The two groups are then
			 * combined with AND into one clause, because Drizzle's `.where()` accepts
			 * only a single condition and would silently ignore a second one.
			 */
			function convertWhereClause(where: Where[], model: string): SQL[] {
				const schemaModel = getSchema(model);
				if (!where) return [];
				const toCondition = (w: Where) => buildCondition(w, model, schemaModel);
				if (where.length === 1) {
					const only = where[0];
					return only ? [toCondition(only)] : [];
				}
				const andConditions = where
					.filter((w) => w.connector === 'AND' || !w.connector)
					.map(toCondition);
				const orConditions = where
					.filter((w) => w.connector === 'OR')
					.map(toCondition);
				const combined = and(
					andConditions.length ? and(...andConditions) : undefined,
					orConditions.length ? or(...orConditions) : undefined,
				);
				return combined ? [combined] : [];
			}

			/**
			 * Throws if any key in `values` is not a column of the model's schema.
			 */
			function checkMissingFields(
				schema: Record<string, any>,
				model: string,
				values: Record<string, any>,
			) {
				if (!schema) {
					throw new BetterAuthError(
						'Drizzle adapter failed to initialize. Schema not found. Please provide a schema object in the adapter options object.',
					);
				}
				for (const key in values) {
					if (!schema[key]) {
						throw new BetterAuthError(
							`The field "${key}" does not exist in the "${model}" schema. Please update your drizzle schema or re-generate using "npx @better-auth/cli generate".`,
						);
					}
				}
			}

			return {
				async create({ model, data: values }) {
					const schemaModel = getSchema(model);
					// With generateIds === false, drop the id better-auth supplied so the
					// schema's default function (or the database) generates it instead.
					const finalValues =
						config.generateIds === false && 'id' in schemaModel
							? Object.fromEntries(
									Object.entries(values).filter(([key]) => key !== 'id'),
								)
							: { ...values };
					checkMissingFields(schemaModel, model, finalValues);
					const builder = db.insert(schemaModel).values(finalValues);
					const returned = await withReturning(model, builder, finalValues);
					return returned;
				},
				async findOne({ model, where }) {
					const schemaModel = getSchema(model);
					const clause = convertWhereClause(where, model);
					// Only the first match is used, so don't fetch more than one row.
					const res = await db
						.select()
						.from(schemaModel)
						.where(...clause)
						.limit(1);
					if (!res.length) return null;
					return res[0];
				},
				async findMany({ model, where, sortBy, limit, offset }) {
					const schemaModel = getSchema(model);
					const clause = where ? convertWhereClause(where, model) : [];
					const sortFn = sortBy?.direction === 'desc' ? desc : asc;
					const builder = db
						.select()
						.from(schemaModel)
						.limit(limit || 100)
						.offset(offset || 0);
					if (sortBy?.field) {
						builder.orderBy(
							sortFn(
								schemaModel[getFieldName({ model, field: sortBy?.field })],
							),
						);
					}
					return (await builder.where(...clause)) as any[];
				},
				async count({ model, where }) {
					const schemaModel = getSchema(model);
					const clause = where ? convertWhereClause(where, model) : [];
					const res = await db
						.select({ count: count() })
						.from(schemaModel)
						.where(...clause);
					return res[0].count;
				},
				async update({ model, where, update: values }) {
					const schemaModel = getSchema(model);
					const clause = convertWhereClause(where, model);
					const builder = db
						.update(schemaModel)
						.set(values)
						.where(...clause);
					return await withReturning(model, builder, values as any, where);
				},
				async updateMany({ model, where, update: values }) {
					const schemaModel = getSchema(model);
					const clause = convertWhereClause(where, model);
					const builder = db
						.update(schemaModel)
						.set(values)
						.where(...clause);
					return await builder;
				},
				async delete({ model, where }) {
					const schemaModel = getSchema(model);
					const clause = convertWhereClause(where, model);
					const builder = db.delete(schemaModel).where(...clause);
					return await builder;
				},
				async deleteMany({ model, where }) {
					const schemaModel = getSchema(model);
					const clause = convertWhereClause(where, model);
					const builder = db.delete(schemaModel).where(...clause);
					return await builder;
				},
				options: config,
			};
		},
	});
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment