Last active
December 2, 2021 13:19
-
-
Save supplient/9e606ea1790c2e5b775609f7a69dc5a3 to your computer and use it in GitHub Desktop.
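一个基于显式欧拉法的简单布料模拟,基于mass-spring模型,包含张力、平面内剪切力、平面外弯曲力、风力 (A simple mass-spring cloth simulation integrated with an explicit — strictly, semi-implicit/symplectic, since positions are advanced with the updated velocities — Euler scheme, including tension, in-plane shear, out-of-plane bending and wind forces.)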
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import taichi as ti | |
import numpy as np | |
from random import random | |
from tqdm import tqdm | |
from math import sin, pi, sqrt | |
from timer import Timer | |
################
# Init taichi
################
# Taichi is initialised before any of the Taichi fields further down are
# declared; a GPU backend is requested.
ti.init(arch=ti.gpu)

# Shared stopwatch used by TimeTick and the main loop to time each stage.
timer = Timer()
################
# Input data
################
# Grid resolution: xn x yn particles.
xn = 20
yn = xn

# Integration / physical constants.
gravityConstant = 9.8
timestep = 0.001  # seconds per solver step
m = 0.0008  # kg, mass of one particle
# Compared against particle velocities in the wind drag term, i.e. it acts
# as a wind velocity despite its name.
windForce = [2.4, 0, -2.4]

# Spring rest lengths: direct neighbours, diagonals, and two-apart neighbours.
lConst = 2 / xn
lCutConst = sqrt(2) * lConst
lBendConst = 2 * lConst

# (stiffness, damping) per spring family.
# TODO adjust parameters
k_tension, d_tension = 10, 2
k_cut, d_cut = 1, 0.2
k_bend, d_bend = 0.1, 0.02
# Wind drag and (currently unused) lift coefficients.
k_drag, k_rise = 1, 1
################
# Taichi Tensors
################
# NOTE(review): `ti.Matrix(n, dt=..., shape=...)` and `ti.var(...)` are the
# legacy field-declaration API; newer Taichi releases spell these
# `ti.Vector.field(n, dtype=..., shape=...)` and `ti.field(...)`. Confirm the
# Taichi version in use before migrating.
# Per-particle simulation state on the xn x yn grid.
x = ti.Matrix(3, dt=ti.f32, shape=(xn, yn))  # positions
v = ti.Matrix(3, dt=ti.f32, shape=(xn, yn))  # velocities
x_next = ti.Matrix(3, dt=ti.f32, shape=(xn, yn))  # positions after the current step
v_next = ti.Matrix(3, dt=ti.f32, shape=(xn, yn))  # velocities after the current step
w = ti.Matrix(1, dt=ti.f32, shape=(xn, yn))  # inverse mass; 0 keeps a particle fixed
F = ti.Matrix(3, dt=ti.f32, shape=(xn, yn))  # net force accumulated by CalF
# Flattened (row-major, see toInd) buffers handed to the GUI every frame.
x_show = ti.Matrix(2, dt=ti.f32, shape=(xn*yn))  # screen position per particle
color_show = ti.var(ti.i32, shape=(xn*yn))  # 0xRRGGBB colour per particle
# Grid edges from (i, j) to (i+1, j): (xn-1)*yn of them.
ver_line_begin_show = ti.Matrix(2, dt=ti.f32, shape=((xn-1)*yn))
ver_line_end_show = ti.Matrix(2, dt=ti.f32, shape=((xn-1)*yn))
ver_line_color_show = ti.var(ti.i32, shape=((xn-1)*yn))
# Grid edges from (i, j) to (i, j+1): xn*(yn-1) of them.
hor_line_begin_show = ti.Matrix(2, dt=ti.f32, shape=(xn*(yn-1)))
hor_line_end_show = ti.Matrix(2, dt=ti.f32, shape=(xn*(yn-1)))
hor_line_color_show = ti.var(ti.i32, shape=(xn*(yn-1)))
################ | |
# Support Functions in ti space | |
################ | |
@ti.func
def Distance(a, b):
    # Euclidean distance between the 3-component points a and b.
    # The accumulator must start as a float literal: Taichi fixes a local
    # variable's type at its first assignment, so `res = 0` would make it an
    # integer and truncate every squared term added below.
    res = 0.0
    res += (a[0]-b[0]) * (a[0]-b[0])
    res += (a[1]-b[1]) * (a[1]-b[1])
    res += (a[2]-b[2]) * (a[2]-b[2])
    return ti.sqrt(res)
@ti.func
def Length(a):
    # Euclidean norm of the first three components of `a`.
    sq_sum = 0.0
    # Unrolled at compile time; terms are summed in the order x, y, z.
    for axis in ti.static(range(3)):
        sq_sum += a[axis] * a[axis]
    return ti.sqrt(sq_sum)
@ti.func
def Normalize(a):
    # Scale the 3-component vector `a` to unit length and return it.
    # NOTE: Taichi documents ti.func arguments as passed by value, so the
    # in-place division below may not reach the caller's variable. Use the
    # returned vector instead, e.g. `v = Normalize(v)`.
    # A zero-length vector is returned unchanged instead of dividing by zero.
    l = Length(a)
    if l > 0:
        a[0] /= l
        a[1] /= l
        a[2] /= l
    return a
@ti.func
def DotProd(a, b):
    # Dot product over the first three components of a and b.
    px = a[0] * b[0]
    py = a[1] * b[1]
    pz = a[2] * b[2]
    return px + py + pz
@ti.func
def CrossProd(a, b):
    # Cross product a x b, delegating to the Taichi vector's own `cross`
    # method (thin wrapper kept for symmetry with DotProd).
    return a.cross(b)
@ti.func
def world2screen(x, y):
    # Map world coordinates in [-1, 1] to GUI coordinates in [0, 1].
    # Returns the pair (screen_x, screen_y).
    half_extent = 1.0
    lo = -half_extent
    span = 2 * half_extent
    return ((x - lo) / span, (y - lo) / span)
@ti.func
def depth2color(d):
    # Map a depth value (the z coordinate) to a grey 0xRRGGBB integer.
    # Depths in [-1.8, 0.5] map linearly from black to white; values outside
    # that range are clamped.
    hi = 0.5
    lo = -1.8
    t = (d - lo) / (hi - lo)
    t = max(min(t, 1.0), 0)
    level = int(0xFF * t)
    # Same intensity in all three channels: R << 16 | G << 8 | B.
    return level * (256 * 256) + level * 256 + level
################ | |
# Kernels | |
################ | |
@ti.kernel
def InitX():
    # Lay the particles out as a flat grid centred on the origin in the
    # z = 0 plane, spaced lConst apart.
    for i, j in x:
        offset_i = i - (xn - 1) / 2
        offset_j = j - (yn - 1) / 2
        x[i, j] = [offset_i * lConst, offset_j * lConst, 0]
@ti.kernel
def InitV():
    # Start every particle at rest.
    for I in ti.grouped(v):
        v[I] = [0, 0, 0]
@ti.kernel
def InitW():
    # Fill the inverse masses. Row i == 0 gets 0, so UpdateV never changes
    # its velocity and that edge of the cloth stays fixed.
    for i, j in w:
        if i != 0:
            w[i, j] = [1 / m]
        else:
            w[i, j] = [0]
@ti.func
def CalForceFromNeighbor(i, j, a, b, k, d, l0):
    # Accumulate into F[i, j] the damped spring force that particle (a, b)
    # exerts on particle (i, j). A neighbour index outside the xn x yn grid
    # contributes nothing.
    #   k  - spring stiffness
    #   d  - spring damping coefficient
    #   l0 - spring rest length
    if a>=0 and a<xn and b>=0 and b<yn:
        xDelta = x[a,b] - x[i,j]
        xDeltaLen = Length(xDelta)
        # NOTE(review): Taichi documents ti.func arguments as passed by value,
        # so this call probably does NOT normalise xDelta in place. If so, the
        # two force terms below scale with |xDelta| and |xDelta|^2 rather than
        # using a unit direction. The k/d parameters appear to be tuned
        # against that behaviour. Confirm before switching to a true unit
        # vector (e.g. xDelta / xDeltaLen): the effective damping would grow
        # by roughly 1/lConst^2 and may destabilise the explicit integrator.
        Normalize(xDelta)
        vDelta = v[a,b] - v[i,j]
        F[i,j] += k * (xDeltaLen - l0) * xDelta # spring stretch force (Hooke term)
        F[i,j] += d * DotProd(xDelta, vDelta) * xDelta # spring damping along the spring axis
@ti.func
def CalWindForceFromNeighbor(i, j, a, b):
    # Add to F[i, j] one quarter of the wind drag acting on the grid quad
    # whose lowest-index corner is (a, b), i.e. the quad (a..a+1, b..b+1).
    # CalF calls this for each of the four quads touching (i, j), so every
    # quad's drag is shared equally among its four corners. Quads that would
    # extend past the grid are skipped.
    if a>=0 and a<xn-1 and b>=0 and b<yn-1:
        # Average velocity of the quad's four corners.
        v_quad = v[a,b] + v[a+1,b] + v[a,b+1] + v[a+1,b+1]
        v_quad /= 4
        # Velocity relative to the air; windForce is used as a wind velocity.
        v_rel = v_quad - windForce
        norm = CrossProd(x[a,b]-x[a+1,b], x[a,b]-x[a,b+1])
        # NOTE(review): as in CalForceFromNeighbor, ti.func arguments are
        # documented as pass-by-value, so `norm` is probably left
        # un-normalised here. Its length is the first triangle's
        # parallelogram area, which scales the projection below. Confirm
        # before changing, since k_drag may be tuned against it.
        Normalize(norm)
        # Quad area: each cross-product magnitude is a parallelogram (twice a
        # triangle), so the halved sum equals the two triangles' total area.
        S = Length(CrossProd(x[a,b]-x[a+1,b], x[a,b]-x[a,b+1]))
        S += Length(CrossProd(x[a+1,b+1]-x[a+1,b], x[a+1,b+1]-x[a,b+1]))
        S /= 2
        # Weight by how squarely the relative flow hits the surface.
        S *= ti.abs(DotProd(norm, v_rel))
        # Drag opposes the relative velocity; 0.25 = this corner's share.
        F[i,j] += - 0.25 * S * k_drag * v_rel
        # Lift term (k_rise), currently disabled:
        # F[i,j] += - 0.25 * S * k_rise * CrossProd(v_rel, CrossProd(norm, v_rel).normalized())
@ti.kernel
def CalF():
    # Rebuild the net force on every particle for the current step.
    # Contributions are added in a fixed order: tension, shear, bending, wind.
    for i, j in F:
        F[i, j] = [0, 0, 0]
        # Tension: structural springs to the 4 edge-adjacent neighbours.
        for di, dj in ti.static([(-1, 0), (1, 0), (0, -1), (0, 1)]):
            CalForceFromNeighbor(i, j, i + di, j + dj, k_tension, d_tension, lConst)
        # In-plane shear: springs to the 4 diagonal neighbours.
        for di, dj in ti.static([(-1, -1), (1, -1), (-1, 1), (1, 1)]):
            CalForceFromNeighbor(i, j, i + di, j + dj, k_cut, d_cut, lCutConst)
        # Out-of-plane bending: springs to the neighbours two cells away.
        for di, dj in ti.static([(-2, 0), (2, 0), (0, -2), (0, 2)]):
            CalForceFromNeighbor(i, j, i + di, j + dj, k_bend, d_bend, lBendConst)
        # Gravity is currently disabled:
        # F[i, j] += [0.0, -m*gravityConstant, 0.0]
        # Wind: the 4 quads sharing corner (i, j), each named by its
        # lowest-index corner.
        for qi, qj in ti.static([(-1, -1), (-1, 0), (0, -1), (0, 0)]):
            CalWindForceFromNeighbor(i, j, i + qi, j + qj)
@ti.kernel
def UpdateV():
    # Velocity step: v_next = v + (F / mass) * dt, using the stored inverse
    # mass (0 for fixed particles).
    for I in ti.grouped(v_next):
        inv_mass = w[I][0]
        v_next[I] = v[I] + inv_mass * F[I] * timestep
@ti.kernel
def UpdateX():
    # Position step using the freshly computed velocity v_next (not v).
    for I in ti.grouped(x_next):
        displacement = v_next[I] * timestep
        x_next[I] = x[I] + displacement
################ | |
# Time Loop Functions | |
################ | |
def Update():
    """Integrate one step: velocities first, then positions (into *_next)."""
    for step in (UpdateV, UpdateX):
        step()
@ti.kernel
def SwitchXV():
    # Commit the step: copy the *_next buffers back into the current state.
    for I in ti.grouped(x):
        v[I] = v_next[I]
        x[I] = x_next[I]
@ti.func
def toInd(i, j, nj=yn):
    # Row-major flat index of (i, j) in a grid with nj columns.
    return j + nj * i
@ti.kernel
def PrepareShow():
    # Fill the flat GUI buffers from the particle state: one point per
    # particle, plus the grid edges to the (i+1, j) and (i, j+1) neighbours,
    # all coloured by the particle's depth (z).
    for i, j in x:
        now = world2screen(x[i, j][0], x[i, j][1])
        color = depth2color(x[i, j][2])
        x_show[toInd(i, j)] = now
        color_show[toInd(i, j)] = color
        # Neighbour positions are read only inside the guards. On the last
        # row/column, x[i+1, j] / x[i, j+1] are out of bounds, and reading
        # them is undefined behaviour even if the value is then discarded.
        if i < xn - 1:
            nexti = world2screen(x[i + 1, j][0], x[i + 1, j][1])
            ver_line_begin_show[toInd(i, j)] = now
            ver_line_end_show[toInd(i, j)] = nexti
            ver_line_color_show[toInd(i, j)] = color
        if j < yn - 1:
            nextj = world2screen(x[i, j + 1][0], x[i, j + 1][1])
            hor_line_begin_show[toInd(i, j, yn - 1)] = now
            hor_line_end_show[toInd(i, j, yn - 1)] = nextj
            hor_line_color_show[toInd(i, j, yn - 1)] = color
def TimeTick(t):
    """Advance the simulation by one timestep, timing each stage.

    ``t`` is the step's time as computed by the caller; no stage uses it.
    """
    timer.Tick()
    stages = (
        (CalF, "CalF: "),
        (Update, "Update: "),
        (SwitchXV, "SwitchXV: "),
    )
    for run_stage, label in stages:
        run_stage()
        timer.TickAndLog(label)
################
# Init tensors
################
# Initial state: flat grid, at rest, row i == 0 fixed (zero inverse mass).
for _initKernel in (InitX, InitV, InitW):
    _initKernel()
################
# GUI Settings
################
# True: show the simulation live in a window; False: record frames to ./video.
show_on_screen = True
seconds = 10  # length of the recorded clip; only used when recording
fps = 24
if show_on_screen:
    fps = 60
# The VideoManager is only written to when recording, so it is not created
# for a live on-screen run. When recording, fps == 24, matching the original
# framerate.
videoManager = None
if not show_on_screen:
    videoManager = ti.VideoManager(output_dir="./video", framerate=fps, automatic_build=False)
gui = ti.GUI(background_color=0x222222)
# NOTE(review): not referenced anywhere in this file; candidate for removal.
colors = [0xFF0000, 0x00FFFF, 0x0000FF, 0xFFFF00, 0xFF00FF, 0x00FF00]
# Solver steps per displayed frame. 1/fps/timestep is generally not an exact
# integer in floating point (about 16.67 for 60 fps, 41.67 for 24 fps).
# int() truncates, and would drop a whole step whenever the quotient lands
# just below an integer, so round to the nearest count (at least 1).
timestepsPerFrame = max(1, round(1.0 / fps / timestep))
totalFrame = seconds * fps
if show_on_screen:
    totalFrame = 1000000  # effectively unbounded for a live run
################
# Main Loop
################
timer.Init()
for frameCount in tqdm(range(totalFrame)):
    # In live mode, stop once the user has closed the window instead of
    # iterating through the remaining frames with no window to draw into.
    if show_on_screen and not gui.running:
        break
    # Advance the simulation by one displayed frame.
    for t in range(timestepsPerFrame):
        TimeTick(timestep * t + frameCount / fps)
    timer.Tick()
    PrepareShow()
    timer.TickAndLog("PrepareShow: ")
    # Radius and colour are passed by keyword so the calls do not depend on
    # the GUI methods' positional parameter order.
    # NOTE(review): that order (radius vs colour) has not been stable across
    # Taichi releases; confirm.
    gui.circles(x_show.to_numpy(), radius=3, color=color_show.to_numpy())
    gui.lines(ver_line_begin_show.to_numpy(), ver_line_end_show.to_numpy(),
              radius=1, color=ver_line_color_show.to_numpy())
    gui.lines(hor_line_begin_show.to_numpy(), hor_line_end_show.to_numpy(),
              radius=1, color=hor_line_color_show.to_numpy())
    timer.TickAndLog("Draw: ")
    if show_on_screen:
        gui.show()
    else:
        img = gui.get_image()
        videoManager.write_frame(img)
    gui.clear()
if not show_on_screen:
    videoManager.make_video(gif=True, mp4=False)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
class Timer:
    """Wall-clock stopwatch that reports milliseconds between ticks."""

    def __init__(self, doLog=False):
        """Create a timer.

        Args:
            doLog: when True, TickAndLog prints every measured interval.
        """
        # Start from "now" so a Tick() before Init() measures time since
        # construction rather than since perf_counter()'s arbitrary epoch.
        self.tLast = time.perf_counter()
        self.doLog = doLog

    def Init(self):
        """Reset the reference point to the current time."""
        self.tLast = time.perf_counter()

    def Tick(self):
        """Return milliseconds since the previous Tick/Init, and restart."""
        now = time.perf_counter()
        elapsed = now - self.tLast
        self.tLast = now
        return elapsed * 1000  # ms

    def TickAndLog(self, pre_s):
        """Tick; when doLog is set, print ``pre_s`` followed by the interval.

        Returns:
            The elapsed milliseconds, as returned by Tick().
        """
        timeDelta = self.Tick()
        if self.doLog:
            print(pre_s, timeDelta)
        return timeDelta
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment